diff --git a/.agent/skills/translation/SKILL.md b/.agent/skills/translation/SKILL.md new file mode 100644 index 00000000..fcc3e2ca --- /dev/null +++ b/.agent/skills/translation/SKILL.md @@ -0,0 +1,137 @@ +--- +name: Project General Translation & Terminology Guidelines +description: Definitive guidelines, contextual awareness strategies, standard terminology, and comment formatting rules for translating code, configurations, and documentation from Chinese to English in this repository. +--- + +# 🤖 Systemic Translation & Terminology Instructions for AI Agents + +This document is the absolute source of truth and **Standard Operating Procedure (SOP)** for translating Chinese comments, configurations, and documentation into English within this repository. + +**ATTENTION AI AGENTS:** You are NOT merely translating words; you are executing a systematic algorithm to localize complex streaming media and networking concepts. Do not rely solely on "passive reading" or "translation memory." You MUST follow the rigid workflow outlined below. + +--- + +## Phase 1: Contextual Anchoring (MANDATORY BEFORE TRANSLATION) + +Before translating any block of text, you must explicitly anchor yourself to the specific technical domain. **Literal translation of Chinese industry slang (黑话) is strictly prohibited.** + +1. **Identify the Domain:** Look at the module or configuration section (e.g., `[rtp_proxy]`, `[http]`, `[general]`, `[hls]`). +2. **Setup the Mental Lexicon:** + - If `[api]/[http]`: Anchor to standard REST API and Web server concepts (e.g., `Requests/Responses`, `CORS`, `Forwarded IPs`). + - If `Network I/O / [general]`: Anchor to socket programming and OS-level terms (e.g., `Write coalescing`, `Buffers`, `File handles`). + - If `Media Streaming (RTSP/RTMP/RTC)`: Anchor to multimedia transport concepts (e.g., `GOP`, `Payload`, `B-frames`, `Jitter`, `Visual artifacts`). +3. **Verification-Driven Translation:** If you encounter a Chinese term that sounds colloquial or metaphoric (e.g., “花屏” - flowered screen, “秒开” - open in seconds, “溯源” - trace back to origin), **DO NOT guess or translate literally**. Ask yourself: _"How do top-tier English open-source projects (FFmpeg, WebRTC, Nginx) refer to this specific technical phenomenon?"_ + +--- + +## Phase 2: Structural Translation & Anti-Pattern Detection + +LLMs naturally tend to follow the grammatical structure of the source text. Chinese technical writing often uses sprawling sentences and explanatory fillers. You must actively break these patterns. + +### 🚫 Rule 1: The "Action-Result" Paradigm + +- **Trigger:** When the Chinese text says "设置为0关闭此特性" (Setting this to 0 disables this feature) or "打开此选项会导致..." (Turning this on causes...). +- **Execution:** Force your output to use the exact structure: `Setting this to [Value] disables [Feature] and allows [Consequence].` Do NOT translate explanatory filler like "This mechanism's logic dictates that...". + +### 🚫 Rule 2: Sub-clause Elimination (No "Chinglish") + +- **Trigger:** Long noun clusters or overly personified system descriptions (e.g., "服务器会认为这个流是断开的" - The server will think this stream is disconnected). +- **Execution:** Use direct, objective voice: `The stream is considered disconnected.` or `The system drops the stream.` + +### 🚫 Rule 3: Clarifying Ambiguous Actions + +- **Trigger:** The word `忽略` (Ignore/Skip) vs. `丢弃/放弃` (Abandon/Drop). +- **Execution:** Use `Ignore` or `Skip` for non-critical timeouts (e.g., waiting for a track to be ready). Reserve `Abandon`, `Drop` or `Disconnect` only for fatal errors or closed sockets. + +### 🚫 Rule 4: Zero Information Loss & Causal Reconstruction + +- **Trigger:** When condensing text for native flow, or translating complex caveats (e.g., parenthetical conditions, "而不是" / instead of, side-effects). +- **Execution:** You may reorganise syntax to sound professional, but you MUST NOT drop crucial qualifiers, modifiers, or side effects. If a Chinese config says "instead of returning X via hook", the English translation must explicitly mention "returning X". Information completeness supersedes structural brevity. + +### 🚫 Rule 5: The Golden Balance (Zero Info Loss vs. Native Phrasing) + +- **The Core Conflict:** You must achieve **Zero Information Loss** WITHOUT resorting to **Chinglish** or literal word-for-word translation. +- **What "Information" Means:** "Retaining information" means capturing 100% of the **technical causality**, **side-effects**, **prerequisites**, and **system boundaries** present in the Chinese text. +- **What "Information" DOES NOT Mean:** It does NOT mean preserving the Chinese grammatical structure, literal phrasing, or colloquialisms (啰嗦句子和字面用词). +- **Execution (The Top-Down Conceptual Approach):** + 1. **Contextual Override:** Never translate a noun literally if the surrounding constraints (e.g., units like "seconds", prefixes, or the specific protocol) dictate a domain term. For example, if a setting is measured in "seconds", the Chinese word "大小" (size) MUST logically translate to `Duration` or `Interval`, **NEVER** `Size`. + 2. **Conceptual Compression:** When faced with a sprawling, explanatory Chinese sentence, distill the _technical payload_ and express it using concise, standard industry jargon. + - _Anti-pattern (Literal/Chinglish):_ `After disabling the traditional authentication mode, you must first call the API to log in. Upon success, a cookie will be set, and all APIs can be accessed without restriction as long as the cookie is valid.` + - _Pro-pattern (Native/Jargon):_ `When disabled, users must first call /index/api/login. Upon success, a cookie auth token is set for subsequent requests.` (Using "subsequent requests" efficiently compresses the lengthy Chinese explanation). + 3. **Technical Abstraction:** Recognize standard operations (e.g., "拉流再推流"). Do not translate the physical actions (`pulling and then pushing`); translate the abstract technical process (`re-publishing` or `re-encoding`). + +### 🚫 Rule 6: Anti-Summarization (Strict Boolean & Causality Preservation) + +- **Trigger:** When applying Conceptual Compression (Rule 5) to a text block containing conditionals or explanations. +- **The Core Conflict:** _Compression_ reduces word count by using jargon. _Summarization_ drops critical logic. **Summarization is strictly forbidden.** +- **Execution (The Boolean Mapping Rule):** + 1. Treat Chinese comments like code blocks. Extract all `IF/THEN/ELSE` branches, prerequisites, and root causes before translating. + 2. If the original text states a "success" path and a "failure" path, the English translation MUST explicitly state both paths. You cannot compress them into a single vague outcome. + 3. If the original text states _why_ a feature exists (the exact cause or defect being prevented), the English translation MUST explicitly state that exact cause. You cannot compress it into generic "to improve performance" or "to prevent errors." + 4. Perform a **Reverse Mapping Check**: After writing the English sentence, ask yourself—"If I reverse-compile this English back to Chinese, would any `IF` conditions or edge-case explanations be missing?" If yes, rewrite it completely to restore the dropped logic. + +--- + +## Phase 3: The Hardcoded Terminology Dictionary + +**CRITICAL:** When translating, if you encounter these Chinese concepts, you MUST use the exact, first provided English term. **Do not mix or alternate synonyms.** + +### Network & Architecture + +- 源站 -> `Origin server` +- 溯源 (拉流) -> `Origin pull` +- 推流代理 / 拉流代理 -> `Publishing proxies` / `Pulling proxies` +- 按需拉流 -> `On-demand stream pulling` +- 集群 -> `Cluster` +- 推流断开后的超时等待 -> `Grace period for publisher reconnection` + +### Video & Playback Experience + +- 秒开 / 极速秒开 -> `Instant playback (zero-delay startup)` (e.g., 级联秒开 -> `Instant playback for cascaded streams`) +- 花屏 -> `Visual artifacts (glitches)` _(NEVER use "Screen tearing", which is a hardware V-sync issue)_ +- 卡顿 -> `Playback stuttering` + +### System I/O & HTTP + +- 合并写 -> `Write coalescing` _(NEVER use "Merged write")_ +- 请求和回复 -> `Requests and Responses` _(Avoid "Replies")_ +- 在代理后方获取真实IP -> `Extract the real client IP when behind a proxy (e.g., via X-Forwarded-For)` + +### General Technical Terms + +- 切片 -> `Segment` (e.g., HLS segment) +- 封装 / 打包 -> `Packaging` +- 负载 -> `Payload` +- 鉴权 -> `Authentication` +- 处理 / 应对 (故障) -> `Handle` or `Address` + +--- + +## Phase 4: Strict Formatting Rules (CRITICAL) + +When translating comments inside code files (`.cpp`, `.h`) or configs (`.ini`), apply these hard constraints: + +1. **Bilingual Retention:** Unless explicitly instructed to delete Chinese, **ALWAYS retain the original Chinese comments**. +2. **Bottom Placement:** Place the English translation immediately **below** the Chinese line or block. +3. **Block Uniformity:** Do NOT translate line-by-line (`ZH-EN-ZH-EN`). If a Chinese comment is a 3-line block, output it as a 3-line Chinese block followed by a 3-line English block. + +```cpp +/* + * 这里是第一行中文描述。 + * 这里是第二行中文补充。 + */ +/* + * This is the English translation of the first line. + * This is the English translation of the second line. + */ +``` + +--- + +## Phase 5: The Post-Translation Verification Workflow (DO NOT SKIP) + +If you are asked to review or update translations in a long file, **you cannot rely solely on passive reading**. You MUST execute this workflow: + +1. **Active Scan (Regex/Search):** Before reading the document, use file search tools to actively scan for known anti-patterns in the current English text (e.g., search for `Screen tearing`, `Merged write`, `Replies`, `Source station`). Fix them immediately. +2. **Format Review:** Scan for `ZH-EN-ZH-EN` interleaving and fix it to block format. +3. **Blind English Review:** After translating, hide the Chinese text from your mental context. Read _only_ your English output constraint: _Does this sound like a snippet from the official Nginx or WebRTC manuals? Is it concise (CBD: Clarity, Brevity, Directness)?_ If it sounds like a literal word-for-word translation, rewrite it natively. diff --git a/.claude/skills b/.claude/skills new file mode 120000 index 00000000..9b058317 --- /dev/null +++ b/.claude/skills @@ -0,0 +1 @@ +../.agent/skills \ No newline at end of file diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 83b13007..12b196d9 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -2,7 +2,7 @@ name: Android on: [push, pull_request] jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - name: 下载源码 diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 75f23ee6..f6bd2795 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -5,7 +5,7 @@ on: [push, pull_request] jobs: analyze: name: Analyze - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 permissions: actions: read contents: read @@ -43,7 +43,7 @@ jobs: with: repository: cisco/libsrtp fetch-depth: 1 - ref: v2.3.0 + ref: v2.7.0 path: 3rdpart/libsrtp - name: 编译 SRTP diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 5f2b007f..e13fd5fe 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -15,7 +15,7 @@ env: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 permissions: contents: read packages: write diff --git a/.github/workflows/issue_lint.yml b/.github/workflows/issue_lint.yml index c33e4f72..b4e4e560 100644 --- a/.github/workflows/issue_lint.yml +++ b/.github/workflows/issue_lint.yml @@ -6,7 +6,7 @@ on: jobs: issue_lint: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index a85a748d..a1ffd256 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -5,7 +5,7 @@ on: [push, pull_request] jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v1 diff --git a/.github/workflows/linux_py.yml b/.github/workflows/linux_py.yml new file mode 100644 index 00000000..86acdb54 --- /dev/null +++ b/.github/workflows/linux_py.yml @@ -0,0 +1,172 @@ +name: Linux_Python + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v1 + + - name: 下载submodule源码 + run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init + + - name: 下载 SRTP + uses: actions/checkout@v2 + with: + repository: cisco/libsrtp + fetch-depth: 1 + ref: v2.3.0 + path: 3rdpart/libsrtp + + - name: 下载 openssl + uses: actions/checkout@v2 + with: + repository: openssl/openssl + fetch-depth: 1 + ref: OpenSSL_1_1_1 + path: 3rdpart/openssl + + - name: 下载 usrsctp + uses: actions/checkout@v2 + with: + repository: sctplab/usrsctp + fetch-depth: 1 + ref: 0.9.5.0 + path: 3rdpart/usrsctp + + - name: 启动 Docker 容器, 在Docker 容器中执行脚本 + run: | + docker pull centos:7 + docker run -v $(pwd):/root -w /root --rm centos:7 sh -c " + #!/bin/bash + set -x + + # Backup original CentOS-Base.repo file + cp /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup + + # Define new repository configuration + cat < /etc/yum.repos.d/CentOS-Base.repo + [base] + name=CentOS-7 - Base - mirrors.aliyun.com + baseurl=http://mirrors.aliyun.com/centos/7/os/x86_64/ + gpgcheck=1 + gpgkey=http://mirrors.aliyun.com/centos/RPM-GPG-KEY-CentOS-7 + + [updates] + name=CentOS-7 - Updates - mirrors.aliyun.com + baseurl=http://mirrors.aliyun.com/centos/7/updates/x86_64/ + gpgcheck=1 + gpgkey=http://mirrors.aliyun.com/centos/RPM-GPG-KEY-CentOS-7 + EOF + cat > /etc/yum.repos.d/epel-aliyun.repo < /etc/yum.repos.d/CentOS-SCLo-aliyun.repo <" "_")" >> $GITHUB_ENV + echo "BRANCH2=$(echo ${GITHUB_REF#refs/heads/} )" >> $GITHUB_ENV + echo "DATE=$(date +%Y-%m-%d)" >> $GITHUB_ENV + + - name: 打包二进制 + id: upload + uses: actions/upload-artifact@v4 + with: + name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }} + path: release/* + if-no-files-found: error + retention-days: 90 + + - name: issue评论 + if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test' + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + github.rest.issues.createComment({ + issue_number: ${{vars.VERSION_ISSUE_NO}}, + owner: context.repo.owner, + repo: context.repo.repo, + body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n' + + '- 分支: ${{ env.BRANCH2 }}\n' + + '- git hash: ${{ github.sha }} \n' + + '- 编译日期: ${{ env.DATE }}\n' + + '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n' + + '- 打包ci名: ${{ github.workflow }}\n' + + '- 开启特性: openssl/webrtc/datachannel\n' + + '- 说明: 本二进制在centos7(x64)上编译,请确保您的机器系统不低于此版本;本程序依赖python3.11, 运行前请miniconda安装python3.11\n' + }) diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 86758e77..4914a54f 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -18,10 +18,15 @@ jobs: with: vcpkgDirectory: '${{github.workspace}}/vcpkg' vcpkgTriplet: arm64-osx - # 2024.06.01 - vcpkgGitCommitId: '47364fbc300756f64f7876b549d9422d5f3ec0d3' + # 2025.07.11 + vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729' vcpkgArguments: 'openssl libsrtp[openssl] usrsctp' + - name: 安装指定 CMake + uses: jwlawson/actions-setup-cmake@v2 + with: + cmake-version: '3.30.5' + - name: 编译 uses: lukka/run-cmake@v3 with: diff --git a/.github/workflows/macos_py.yml b/.github/workflows/macos_py.yml new file mode 100644 index 00000000..2c29a076 --- /dev/null +++ b/.github/workflows/macos_py.yml @@ -0,0 +1,80 @@ +name: macOS_Python + +on: [push, pull_request] + +jobs: + build: + + runs-on: macOS-latest + + steps: + - uses: actions/checkout@v1 + + - name: 下载submodule源码 + run: mv -f .gitmodules_github .gitmodules && git submodule sync && git submodule update --init + + - name: 配置 vcpkg + uses: lukka/run-vcpkg@v7 + with: + vcpkgDirectory: '${{github.workspace}}/vcpkg' + vcpkgTriplet: arm64-osx + # 2025.07.11 + vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729' + vcpkgArguments: 'openssl libsrtp[openssl] usrsctp' + + - name: 安装指定 CMake + uses: jwlawson/actions-setup-cmake@v2 + with: + cmake-version: '3.30.5' + + - name: 检查并设置 Python 3 + run: | + PYTHON_ROOT=$(python3 -c "import sys; print(sys.prefix)") + echo "PYTHON_ROOT=$PYTHON_ROOT" >> $GITHUB_ENV + PYTHON_EXECUTABLE=$(which python3) + echo "PYTHON_EXECUTABLE=$PYTHON_EXECUTABLE" >> $GITHUB_ENV + + - name: 编译 + uses: lukka/run-cmake@v3 + with: + useVcpkgToolchainFile: true + cmakeBuildType: Release + cmakeListsOrSettingsJson: CMakeListsTxtAdvanced + buildDirectory: '${{github.workspace}}/build' + buildWithCMakeArgs: '--config Release' + cmakeAppendedArgs: '-DPYTHON_EXECUTABLE=${{ env.PYTHON_EXECUTABLE }} -DENABLE_PYTHON=ON -DENABLE_API=OFF -DENABLE_TESTS=OFF -DCMAKE_BUILD_TYPE=Release' + + - name: 设置环境变量 + run: | + echo "BRANCH=$(echo ${GITHUB_REF#refs/heads/} | tr -s "/\?%*:|\"<>" "_")" >> $GITHUB_ENV + echo "BRANCH2=$(echo ${GITHUB_REF#refs/heads/} )" >> $GITHUB_ENV + echo "DATE=$(date +%Y-%m-%d)" >> $GITHUB_ENV + + - name: 打包二进制 + id: upload + uses: actions/upload-artifact@v4 + with: + name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }} + path: release/* + if-no-files-found: error + retention-days: 90 + + - name: issue评论 + if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test' + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + github.rest.issues.createComment({ + issue_number: ${{vars.VERSION_ISSUE_NO}}, + owner: context.repo.owner, + repo: context.repo.repo, + body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n' + + '- 分支: ${{ env.BRANCH2 }}\n' + + '- git hash: ${{ github.sha }} \n' + + '- 编译日期: ${{ env.DATE }}\n' + + '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n' + + '- 打包ci名: ${{ github.workflow }}\n' + + '- 开启特性: openssl/webrtc/datachannel\n' + + '- 说明: 此二进制为arm64版本; 本程序依赖python3.14, 运行前请brew install python@3.14安装\n' + }) \ No newline at end of file diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 03acf75c..dc85e911 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -4,7 +4,7 @@ on: [pull_request] jobs: check: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v2 with: diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 4ebe3ccc..e092e965 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -4,7 +4,7 @@ on: [push, pull_request] jobs: build: - runs-on: windows-2019 + runs-on: windows-2022 steps: - uses: actions/checkout@v1 @@ -17,8 +17,8 @@ jobs: with: vcpkgDirectory: '${{github.workspace}}/vcpkg' vcpkgTriplet: x64-windows-static - # 2024.06.01 - vcpkgGitCommitId: '47364fbc300756f64f7876b549d9422d5f3ec0d3' + # 2025.07.11 + vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729' vcpkgArguments: 'openssl libsrtp[openssl] usrsctp' - name: 编译 diff --git a/.github/workflows/windows_py.yml b/.github/workflows/windows_py.yml new file mode 100644 index 00000000..678177d3 --- /dev/null +++ b/.github/workflows/windows_py.yml @@ -0,0 +1,86 @@ +name: Windows_Python + +on: [push, pull_request] + +jobs: + build: + runs-on: windows-2022 + + steps: + - uses: actions/checkout@v1 + + - name: 下载submodule源码 + run: mv -Force .gitmodules_github .gitmodules && git submodule sync && git submodule update --init + + - name: 配置 vcpkg + uses: lukka/run-vcpkg@v7 + with: + vcpkgDirectory: '${{github.workspace}}/vcpkg' + vcpkgTriplet: x64-windows-static + # 2025.07.11 + vcpkgGitCommitId: 'efcfaaf60d7ec57a159fc3110403d939bfb69729' + vcpkgArguments: 'openssl libsrtp[openssl] usrsctp' + + - name: Setup Python 3.14 + uses: actions/setup-python@v4 + with: + python-version: 3.14 + architecture: x64 + + - name: Set PYTHON_EXECUTABLE + shell: pwsh + run: | + $pythonExe = python -c "import sys; print(sys.executable)" + Add-Content -Path $Env:GITHUB_ENV -Value "PYTHON_EXECUTABLE=$pythonExe" + + - name: Check PYTHON_EXECUTABLE + run: echo $Env:PYTHON_EXECUTABLE + shell: pwsh + + - name: 编译 + uses: lukka/run-cmake@v3 + with: + useVcpkgToolchainFile: true + cmakeBuildType: Release + cmakeListsOrSettingsJson: CMakeListsTxtAdvanced + buildDirectory: '${{github.workspace}}/build' + buildWithCMakeArgs: '--config Release' + cmakeAppendedArgs: '-DPYTHON_EXECUTABLE=${{ env.PYTHON_EXECUTABLE }} -DENABLE_PYTHON=ON -DENABLE_API=OFF -DENABLE_TESTS=OFF -DCMAKE_BUILD_TYPE=Release' + + - name: 设置环境变量 + run: | + $dateString = Get-Date -Format "yyyy-MM-dd" + $branch = $env:GITHUB_REF -replace "refs/heads/", "" -replace "[\\/\\\?\%\*:\|\x22<>]", "_" + $branch2 = $env:GITHUB_REF -replace "refs/heads/", "" + echo "BRANCH=$branch" >> $env:GITHUB_ENV + echo "BRANCH2=$branch2" >> $env:GITHUB_ENV + echo "DATE=$dateString" >> $env:GITHUB_ENV + + - name: 打包二进制 + id: upload + uses: actions/upload-artifact@v4 + with: + name: ${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }} + path: release/* + if-no-files-found: error + retention-days: 90 + + - name: issue评论 + if: github.event_name != 'pull_request' && github.ref != 'refs/heads/feature/test' + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + github.rest.issues.createComment({ + issue_number: ${{vars.VERSION_ISSUE_NO}}, + owner: context.repo.owner, + repo: context.repo.repo, + body: '- 下载地址: [${{ github.workflow }}_${{ env.BRANCH }}_${{ env.DATE }}](${{ steps.upload.outputs.artifact-url }})\n' + + '- 分支: ${{ env.BRANCH2 }}\n' + + '- git hash: ${{ github.sha }} \n' + + '- 编译日期: ${{ env.DATE }}\n' + + '- 编译记录: [${{ github.run_id }}](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }})\n' + + '- 打包ci名: ${{ github.workflow }}\n' + + '- 开启特性: openssl/webrtc/datachannel\n' + + '- 说明: 此二进制为x64版本;本程序依赖python3.14, 运行前请先安装python3.14\n' + }) diff --git a/.gitmodules b/.gitmodules index c6211ba1..af38aadf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "www/webassist"] path = www/webassist url = https://gitee.com/victor1002/zlm_webassist +[submodule "3rdpart/pybind11"] + path = 3rdpart/pybind11 + url = https://gitee.com/mirrors/pybind11.git diff --git a/.gitmodules_github b/.gitmodules_github index 87b576ee..f3b18b57 100644 --- a/.gitmodules_github +++ b/.gitmodules_github @@ -9,4 +9,7 @@ url = https://github.com/open-source-parsers/jsoncpp.git [submodule "www/webassist"] path = www/webassist - url = https://github.com/1002victor/zlm_webassist \ No newline at end of file + url = https://github.com/1002victor/zlm_webassist +[submodule "3rdpart/pybind11"] + path = 3rdpart/pybind11 + url = https://github.com/pybind/pybind11.git \ No newline at end of file diff --git a/3rdpart/CMakeLists.txt b/3rdpart/CMakeLists.txt index f55fc6c7..8a5a534d 100644 --- a/3rdpart/CMakeLists.txt +++ b/3rdpart/CMakeLists.txt @@ -116,113 +116,18 @@ endif() ############################################################################## # toolkit -# TODO: 改造 toolkit 以便直接引用 - -include(CheckStructHasMember) -include(CheckSymbolExists) - -# 检查 sendmmsg 相关依赖并设置对应的宏, 配置 _GNU_SOURCE 以启用 GNU 扩展特性 -list(APPEND CMAKE_REQUIRED_DEFINITIONS -D_GNU_SOURCE) -check_struct_has_member("struct mmsghdr" msg_hdr sys/socket.h HAVE_MMSG_HDR) -check_symbol_exists(sendmmsg sys/socket.h HAVE_SENDMMSG_API) -check_symbol_exists(recvmmsg sys/socket.h HAVE_RECVMMSG_API) - -set(COMPILE_DEFINITIONS) -# ToolKit 依赖 ENABLE_OPENSSL 以及 ENABLE_MYSQL -list(FIND MK_COMPILE_DEFINITIONS ENABLE_OPENSSL ENABLE_OPENSSL_INDEX) -if(NOT ENABLE_OPENSSL_INDEX EQUAL -1) - list(APPEND COMPILE_DEFINITIONS ENABLE_OPENSSL) -endif() -list(FIND MK_COMPILE_DEFINITIONS ENABLE_MYSQL ENABLE_MYSQL_INDEX) -if(NOT ENABLE_MYSQL_INDEX EQUAL -1) - list(APPEND COMPILE_DEFINITIONS ENABLE_MYSQL) -endif() -if(HAVE_MMSG_HDR) - list(APPEND COMPILE_DEFINITIONS HAVE_MMSG_HDR) -endif() -if(HAVE_SENDMMSG_API) - list(APPEND COMPILE_DEFINITIONS HAVE_SENDMMSG_API) -endif() -if(HAVE_RECVMMSG_API) - list(APPEND COMPILE_DEFINITIONS HAVE_RECVMMSG_API) -endif() - -# check the socket buffer size set by the upper cmake project, if it is set, use the setting of the upper cmake project, otherwise set it to 256K -# if the socket buffer size is set to 0, it means that the socket buffer size is not set, and the kernel default value is used(just for linux) -if(DEFINED SOCKET_DEFAULT_BUF_SIZE) - if (SOCKET_DEFAULT_BUF_SIZE EQUAL 0) - message(STATUS "Socket default buffer size is not set, use the kernel default value") - else() - message(STATUS "Socket default buffer size is set to ${SOCKET_DEFAULT_BUF_SIZE}") - endif () - add_definitions(-DSOCKET_DEFAULT_BUF_SIZE=${SOCKET_DEFAULT_BUF_SIZE}) -endif() - -set(ToolKit_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/ZLToolKit) -# 收集源代码 -file(GLOB ToolKit_SRC_LIST - ${ToolKit_ROOT}/src/*/*.cpp - ${ToolKit_ROOT}/src/*/*.h - ${ToolKit_ROOT}/src/*/*.c) -if(IOS) - list(APPEND ToolKit_SRC_LIST - ${ToolKit_ROOT}/src/Network/Socket_ios.mm) -endif() - -################################################################### -#使用wepoll windows iocp 模拟 epoll -if(ENABLE_WEPOLL) - if(WIN32) - message(STATUS "Enable wepoll") - #增加wepoll源文件及api参数兼容文件 - list(APPEND ToolKit_SRC_LIST - ${CMAKE_CURRENT_SOURCE_DIR}/wepoll/wepoll.c - ${CMAKE_CURRENT_SOURCE_DIR}/wepoll/sys/epoll.cpp) - #增加wepoll头文件目录 - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/wepoll) - #开启epoll - add_definitions(-DHAS_EPOLL) - endif() -endif() -################################################################### - -# 去除 win32 的适配代码 -if(NOT WIN32) - list(REMOVE_ITEM ToolKit_SRC_LIST ${ToolKit_ROOT}/win32/getopt.c) -else() - # 防止 Windows.h 包含 Winsock.h - list(APPEND COMPILE_DEFINITIONS - WIN32_LEAN_AND_MEAN MP4V2_NO_STDINT_DEFS - # 禁用警告 - _CRT_SECURE_NO_WARNINGS _WINSOCK_DEPRECATED_NO_WARNINGS) -endif() - -# 添加库 -add_library(zltoolkit STATIC ${ToolKit_SRC_LIST}) -add_library(ZLMediaKit::ToolKit ALIAS zltoolkit) -target_compile_definitions(zltoolkit - PUBLIC ${COMPILE_DEFINITIONS}) -target_compile_options(zltoolkit - PRIVATE ${COMPILE_OPTIONS_DEFAULT}) -target_include_directories(zltoolkit - PRIVATE - "$" - PUBLIC - "$/src") - +add_subdirectory(ZLToolKit) +# 添加库别名 +add_library(ZLMediaKit::ToolKit ALIAS ZLToolKit) +# 添加依赖 update_cached_list(MK_LINK_LIBRARIES ZLMediaKit::ToolKit) -if(USE_SOLUTION_FOLDERS AND (NOT GROUP_BY_EXPLORER)) - # 在 IDE 中对文件进行分组, 源文件和头文件分开 - set_file_group(${ToolKit_ROOT}/src ${ToolKit_SRC_LIST}) -endif() +############################################################################## -# 未在使用 -if(ENABLE_CXX_API) - # 保留目录结构 - install(DIRECTORY ${ToolKit_ROOT}/ - DESTINATION ${INSTALL_PATH_INCLUDE}/ZLToolKit - REGEX "(.*[.](md|cpp)|win32)$" EXCLUDE) - install(TARGETS zltoolkit - DESTINATION ${INSTALL_PATH_LIB}) -endif() +if (ENABLE_PYTHON) + # ============ pybind11 lib ============ + add_subdirectory(pybind11) + update_cached_list(MK_LINK_LIBRARIES pybind11::embed) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/pybind11/include) + update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_PYTHON) +endif () \ No newline at end of file diff --git a/3rdpart/ZLToolKit b/3rdpart/ZLToolKit index 04212017..7302286c 160000 --- a/3rdpart/ZLToolKit +++ b/3rdpart/ZLToolKit @@ -1 +1 @@ -Subproject commit 04212017c0dc764f99f1db46240d59dcdf154700 +Subproject commit 7302286cf4be39d416b023fec3fd4ca9c54af762 diff --git a/3rdpart/jsoncpp b/3rdpart/jsoncpp index 69098a18..ca98c984 160000 --- a/3rdpart/jsoncpp +++ b/3rdpart/jsoncpp @@ -1 +1 @@ -Subproject commit 69098a18b9af0c47549d9a271c054d13ca92b006 +Subproject commit ca98c98457b1163cca1f7d8db62827c115fec6d1 diff --git a/3rdpart/media-server b/3rdpart/media-server index 0658496d..21c4451f 160000 --- a/3rdpart/media-server +++ b/3rdpart/media-server @@ -1 +1 @@ -Subproject commit 0658496d5fc7d238f41e10ea4d0a10113a8eed84 +Subproject commit 21c4451ff2e4c4bb1c817e606c8b4e5deac1e719 diff --git a/3rdpart/pybind11 b/3rdpart/pybind11 new file mode 160000 index 00000000..ed5057de --- /dev/null +++ b/3rdpart/pybind11 @@ -0,0 +1 @@ +Subproject commit ed5057ded698e305210269dafa57574ecf964483 diff --git a/3rdpart/wepoll/LICENSE b/3rdpart/wepoll/LICENSE deleted file mode 100644 index d7fc4b11..00000000 --- a/3rdpart/wepoll/LICENSE +++ /dev/null @@ -1,28 +0,0 @@ -wepoll - epoll for Windows -https://github.com/piscisaureus/wepoll - -Copyright 2012-2020, Bert Belder -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/3rdpart/wepoll/README.md b/3rdpart/wepoll/README.md deleted file mode 100644 index d334d083..00000000 --- a/3rdpart/wepoll/README.md +++ /dev/null @@ -1,202 +0,0 @@ -# wepoll - epoll for windows - -[![][ci status badge]][ci status link] - -This library implements the [epoll][man epoll] API for Windows -applications. It is fast and scalable, and it closely resembles the API -and behavior of Linux' epoll. - -## Rationale - -Unlike Linux, OS X, and many other operating systems, Windows doesn't -have a good API for receiving socket state notifications. It only -supports the `select` and `WSAPoll` APIs, but they -[don't scale][select scale] and suffer from -[other issues][wsapoll broken]. - -Using I/O completion ports isn't always practical when software is -designed to be cross-platform. Wepoll offers an alternative that is -much closer to a drop-in replacement for software that was designed -to run on Linux. - -## Features - -* Can poll 100000s of sockets efficiently. -* Fully thread-safe. -* Multiple threads can poll the same epoll port. -* Sockets can be added to multiple epoll sets. -* All epoll events (`EPOLLIN`, `EPOLLOUT`, `EPOLLPRI`, `EPOLLRDHUP`) - are supported. -* Level-triggered and one-shot (`EPOLLONESTHOT`) modes are supported -* Trivial to embed: you need [only two files][dist]. - -## Limitations - -* Only works with sockets. -* Edge-triggered (`EPOLLET`) mode isn't supported. - -## How to use - -The library is [distributed][dist] as a single source file -([wepoll.c][wepoll.c]) and a single header file ([wepoll.h][wepoll.h]).
-Compile the .c file as part of your project, and include the header wherever -needed. - -## Compatibility - -* Requires Windows Vista or higher. -* Can be compiled with recent versions of MSVC, Clang, and GCC. - -## API - -### General remarks - -* The epoll port is a `HANDLE`, not a file descriptor. -* All functions set both `errno` and `GetLastError()` on failure. -* For more extensive documentation, see the [epoll(7) man page][man epoll], - and the per-function man pages that are linked below. - -### epoll_create/epoll_create1 - -```c -HANDLE epoll_create(int size); -HANDLE epoll_create1(int flags); -``` - -* Create a new epoll instance (port). -* `size` is ignored but most be greater than zero. -* `flags` must be zero as there are no supported flags. -* Returns `NULL` on failure. -* [Linux man page][man epoll_create] - -### epoll_close - -```c -int epoll_close(HANDLE ephnd); -``` - -* Close an epoll port. -* Do not attempt to close the epoll port with `close()`, - `CloseHandle()` or `closesocket()`. - -### epoll_ctl - -```c -int epoll_ctl(HANDLE ephnd, - int op, - SOCKET sock, - struct epoll_event* event); -``` - -* Control which socket events are monitored by an epoll port. -* `ephnd` must be a HANDLE created by - [`epoll_create()`](#epoll_createepoll_create1) or - [`epoll_create1()`](#epoll_createepoll_create1). -* `op` must be one of `EPOLL_CTL_ADD`, `EPOLL_CTL_MOD`, `EPOLL_CTL_DEL`. -* `sock` must be a valid socket created by [`socket()`][msdn socket], - [`WSASocket()`][msdn wsasocket], or [`accept()`][msdn accept]. -* `event` should be a pointer to a [`struct epoll_event`](#struct-epoll_event).
- If `op` is `EPOLL_CTL_DEL` then the `event` parameter is ignored, and it - may be `NULL`. -* Returns 0 on success, -1 on failure. -* It is recommended to always explicitly remove a socket from its epoll - set using `EPOLL_CTL_DEL` *before* closing it.
- As on Linux, closed sockets are automatically removed from the epoll set, but - wepoll may not be able to detect that a socket was closed until the next call - to [`epoll_wait()`](#epoll_wait). -* [Linux man page][man epoll_ctl] - -### epoll_wait - -```c -int epoll_wait(HANDLE ephnd, - struct epoll_event* events, - int maxevents, - int timeout); -``` - -* Receive socket events from an epoll port. -* `events` should point to a caller-allocated array of - [`epoll_event`](#struct-epoll_event) structs, which will receive the - reported events. -* `maxevents` is the maximum number of events that will be written to the - `events` array, and must be greater than zero. -* `timeout` specifies whether to block when no events are immediately available. - - `<0` block indefinitely - - `0` report any events that are already waiting, but don't block - - `≥1` block for at most N milliseconds -* Return value: - - `-1` an error occurred - - `0` timed out without any events to report - - `≥1` the number of events stored in the `events` buffer -* [Linux man page][man epoll_wait] - -### struct epoll_event - -```c -typedef union epoll_data { - void* ptr; - int fd; - uint32_t u32; - uint64_t u64; - SOCKET sock; /* Windows specific */ - HANDLE hnd; /* Windows specific */ -} epoll_data_t; -``` - -```c -struct epoll_event { - uint32_t events; /* Epoll events and flags */ - epoll_data_t data; /* User data variable */ -}; -``` - -* The `events` field is a bit mask containing the events being - monitored/reported, and optional flags.
- Flags are accepted by [`epoll_ctl()`](#epoll_ctl), but they are not reported - back by [`epoll_wait()`](#epoll_wait). -* The `data` field can be used to associate application-specific information - with a socket; its value will be returned unmodified by - [`epoll_wait()`](#epoll_wait). -* [Linux man page][man epoll_ctl] - -| Event | Description | -|---------------|----------------------------------------------------------------------| -| `EPOLLIN` | incoming data available, or incoming connection ready to be accepted | -| `EPOLLOUT` | ready to send data, or outgoing connection successfully established | -| `EPOLLRDHUP` | remote peer initiated graceful socket shutdown | -| `EPOLLPRI` | out-of-band data available for reading | -| `EPOLLERR` | socket error1 | -| `EPOLLHUP` | socket hang-up1 | -| `EPOLLRDNORM` | same as `EPOLLIN` | -| `EPOLLRDBAND` | same as `EPOLLPRI` | -| `EPOLLWRNORM` | same as `EPOLLOUT` | -| `EPOLLWRBAND` | same as `EPOLLOUT` | -| `EPOLLMSG` | never reported | - -| Flag | Description | -|------------------|---------------------------| -| `EPOLLONESHOT` | report event(s) only once | -| `EPOLLET` | not supported by wepoll | -| `EPOLLEXCLUSIVE` | not supported by wepoll | -| `EPOLLWAKEUP` | not supported by wepoll | - -1: the `EPOLLERR` and `EPOLLHUP` events may always be reported by -[`epoll_wait()`](#epoll_wait), regardless of the event mask that was passed to -[`epoll_ctl()`](#epoll_ctl). - - -[ci status badge]: https://ci.appveyor.com/api/projects/status/github/piscisaureus/wepoll?branch=master&svg=true -[ci status link]: https://ci.appveyor.com/project/piscisaureus/wepoll/branch/master -[dist]: https://github.com/piscisaureus/wepoll/tree/dist -[man epoll]: http://man7.org/linux/man-pages/man7/epoll.7.html -[man epoll_create]: http://man7.org/linux/man-pages/man2/epoll_create.2.html -[man epoll_ctl]: http://man7.org/linux/man-pages/man2/epoll_ctl.2.html -[man epoll_wait]: http://man7.org/linux/man-pages/man2/epoll_wait.2.html -[msdn accept]: https://msdn.microsoft.com/en-us/library/windows/desktop/ms737526(v=vs.85).aspx -[msdn socket]: https://msdn.microsoft.com/en-us/library/windows/desktop/ms740506(v=vs.85).aspx -[msdn wsasocket]: https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx -[select scale]: https://daniel.haxx.se/docs/poll-vs-select.html -[wsapoll broken]: https://daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ -[wepoll.c]: https://github.com/piscisaureus/wepoll/blob/dist/wepoll.c -[wepoll.h]: https://github.com/piscisaureus/wepoll/blob/dist/wepoll.h diff --git a/3rdpart/wepoll/sys/epoll.cpp b/3rdpart/wepoll/sys/epoll.cpp deleted file mode 100644 index 1d9668b3..00000000 --- a/3rdpart/wepoll/sys/epoll.cpp +++ /dev/null @@ -1,14 +0,0 @@ -/* - * Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved. - * - * This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit). - * - * Use of this source code is governed by MIT license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ -#include "epoll.h" - -std::map toolkit::s_wepollHandleMap; -int toolkit::s_handleIndex = 0; -std::mutex toolkit::s_handleMtx; diff --git a/3rdpart/wepoll/sys/epoll.h b/3rdpart/wepoll/sys/epoll.h deleted file mode 100644 index f40c5af7..00000000 --- a/3rdpart/wepoll/sys/epoll.h +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2016 The ZLToolKit project authors. All Rights Reserved. - * - * This file is part of ZLToolKit(https://github.com/ZLMediaKit/ZLToolKit). - * - * Use of this source code is governed by MIT license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_EPOLL_H -#define ZLMEDIAKIT_EPOLL_H -#include "wepoll.h" -#include -#include - -// 屏蔽 EPOLLET -#define EPOLLET 0 - -namespace toolkit { -// 索引handle -extern std::map s_wepollHandleMap; -extern int s_handleIndex; -extern std::mutex s_handleMtx; -// 屏蔽epoll_create epoll_ctl epoll_wait参数差异 -inline int epoll_create(int size) { - HANDLE handle = ::epoll_create(size); - if (!handle) { - return -1; - } - { - std::lock_guard lck(s_handleMtx); - int idx = ++s_handleIndex; - s_wepollHandleMap[idx] = handle; - return idx; - } -} - -inline int epoll_ctl(int ephnd, int op, SOCKET sock, struct epoll_event *ev) { - HANDLE handle; - { - std::lock_guard lck(s_handleMtx); - handle = s_wepollHandleMap[ephnd]; - } - return ::epoll_ctl(handle, op, sock, ev); -} - -inline int epoll_wait(int ephnd, struct epoll_event *events, int maxevents, int timeout) { - HANDLE handle; - { - std::lock_guard lck(s_handleMtx); - handle = s_wepollHandleMap[ephnd]; - } - return ::epoll_wait(handle, events, maxevents, timeout); -} - -} // namespace toolkit - -#endif // ZLMEDIAKIT_EPOLL_H diff --git a/3rdpart/wepoll/wepoll.c b/3rdpart/wepoll/wepoll.c deleted file mode 100644 index 03cdc2bb..00000000 --- a/3rdpart/wepoll/wepoll.c +++ /dev/null @@ -1,2060 +0,0 @@ -/* - * wepoll - epoll for Windows - * https://github.com/piscisaureus/wepoll - * - * Copyright 2012-2020, Bert Belder - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef WEPOLL_EXPORT -#define WEPOLL_EXPORT -#endif - -#include - -enum EPOLL_EVENTS { - EPOLLIN = (int)(1U << 0), - EPOLLPRI = (int)(1U << 1), - EPOLLOUT = (int)(1U << 2), - EPOLLERR = (int)(1U << 3), - EPOLLHUP = (int)(1U << 4), - EPOLLRDNORM = (int)(1U << 6), - EPOLLRDBAND = (int)(1U << 7), - EPOLLWRNORM = (int)(1U << 8), - EPOLLWRBAND = (int)(1U << 9), - EPOLLMSG = (int)(1U << 10), /* Never reported. */ - EPOLLRDHUP = (int)(1U << 13), - EPOLLONESHOT = (int)(1U << 31) -}; - -#define EPOLLIN (1U << 0) -#define EPOLLPRI (1U << 1) -#define EPOLLOUT (1U << 2) -#define EPOLLERR (1U << 3) -#define EPOLLHUP (1U << 4) -#define EPOLLRDNORM (1U << 6) -#define EPOLLRDBAND (1U << 7) -#define EPOLLWRNORM (1U << 8) -#define EPOLLWRBAND (1U << 9) -#define EPOLLMSG (1U << 10) -#define EPOLLRDHUP (1U << 13) -#define EPOLLONESHOT (1U << 31) - -#define EPOLL_CTL_ADD 1 -#define EPOLL_CTL_MOD 2 -#define EPOLL_CTL_DEL 3 - -typedef void *HANDLE; -typedef uintptr_t SOCKET; - -typedef union epoll_data { - void *ptr; - int fd; - uint32_t u32; - uint64_t u64; - SOCKET sock; /* Windows specific */ - HANDLE hnd; /* Windows specific */ -} epoll_data_t; - -struct epoll_event { - uint32_t events; /* Epoll events and flags */ - epoll_data_t data; /* User data variable */ -}; - -#ifdef __cplusplus -extern "C" { -#endif - -WEPOLL_EXPORT HANDLE epoll_create(int size); -WEPOLL_EXPORT HANDLE epoll_create1(int flags); - -WEPOLL_EXPORT int epoll_close(HANDLE ephnd); - -WEPOLL_EXPORT int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event *event); - -WEPOLL_EXPORT int epoll_wait(HANDLE ephnd, struct epoll_event *events, int maxevents, int timeout); - -#ifdef __cplusplus -} /* extern "C" */ -#endif - -#include - -#include - -#define WEPOLL_INTERNAL static -#define WEPOLL_INTERNAL_EXTERN static - -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnonportable-system-include-path" -#pragma clang diagnostic ignored "-Wreserved-id-macro" -#elif defined(_MSC_VER) -#pragma warning(push, 1) -#endif - -#undef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN - -#undef _WIN32_WINNT -#define _WIN32_WINNT 0x0600 - -#include -#include -#include - -#if defined(__clang__) -#pragma clang diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - -WEPOLL_INTERNAL int nt_global_init(void); - -typedef LONG NTSTATUS; -typedef NTSTATUS *PNTSTATUS; - -#ifndef NT_SUCCESS -#define NT_SUCCESS(status) (((NTSTATUS)(status)) >= 0) -#endif - -#ifndef STATUS_SUCCESS -#define STATUS_SUCCESS ((NTSTATUS)0x00000000L) -#endif - -#ifndef STATUS_PENDING -#define STATUS_PENDING ((NTSTATUS)0x00000103L) -#endif - -#ifndef STATUS_CANCELLED -#define STATUS_CANCELLED ((NTSTATUS)0xC0000120L) -#endif - -#ifndef STATUS_NOT_FOUND -#define STATUS_NOT_FOUND ((NTSTATUS)0xC0000225L) -#endif - -typedef struct _IO_STATUS_BLOCK { - NTSTATUS Status; - ULONG_PTR Information; -} IO_STATUS_BLOCK, *PIO_STATUS_BLOCK; - -typedef VOID(NTAPI *PIO_APC_ROUTINE)(PVOID ApcContext, PIO_STATUS_BLOCK IoStatusBlock, ULONG Reserved); - -typedef struct _UNICODE_STRING { - USHORT Length; - USHORT MaximumLength; - PWSTR Buffer; -} UNICODE_STRING, *PUNICODE_STRING; - -#define RTL_CONSTANT_STRING(s) \ - { sizeof(s) - sizeof((s)[0]), sizeof(s), s } - -typedef struct _OBJECT_ATTRIBUTES { - ULONG Length; - HANDLE RootDirectory; - PUNICODE_STRING ObjectName; - ULONG Attributes; - PVOID SecurityDescriptor; - PVOID SecurityQualityOfService; -} OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES; - -#define RTL_CONSTANT_OBJECT_ATTRIBUTES(ObjectName, Attributes) \ - { sizeof(OBJECT_ATTRIBUTES), NULL, ObjectName, Attributes, NULL, NULL } - -#ifndef FILE_OPEN -#define FILE_OPEN 0x00000001UL -#endif - -#define KEYEDEVENT_WAIT 0x00000001UL -#define KEYEDEVENT_WAKE 0x00000002UL -#define KEYEDEVENT_ALL_ACCESS (STANDARD_RIGHTS_REQUIRED | KEYEDEVENT_WAIT | KEYEDEVENT_WAKE) - -#define NT_NTDLL_IMPORT_LIST(X) \ - X(NTSTATUS, NTAPI, NtCancelIoFileEx, \ - (HANDLE FileHandle, PIO_STATUS_BLOCK IoRequestToCancel, PIO_STATUS_BLOCK IoStatusBlock)) \ - \ - X(NTSTATUS, NTAPI, NtCreateFile, \ - (PHANDLE FileHandle, ACCESS_MASK DesiredAccess, POBJECT_ATTRIBUTES ObjectAttributes, \ - PIO_STATUS_BLOCK IoStatusBlock, PLARGE_INTEGER AllocationSize, ULONG FileAttributes, ULONG ShareAccess, \ - ULONG CreateDisposition, ULONG CreateOptions, PVOID EaBuffer, ULONG EaLength)) \ - \ - X(NTSTATUS, NTAPI, NtCreateKeyedEvent, \ - (PHANDLE KeyedEventHandle, ACCESS_MASK DesiredAccess, POBJECT_ATTRIBUTES ObjectAttributes, ULONG Flags)) \ - \ - X(NTSTATUS, NTAPI, NtDeviceIoControlFile, \ - (HANDLE FileHandle, HANDLE Event, PIO_APC_ROUTINE ApcRoutine, PVOID ApcContext, PIO_STATUS_BLOCK IoStatusBlock, \ - ULONG IoControlCode, PVOID InputBuffer, ULONG InputBufferLength, PVOID OutputBuffer, ULONG OutputBufferLength)) \ - \ - X(NTSTATUS, NTAPI, NtReleaseKeyedEvent, \ - (HANDLE KeyedEventHandle, PVOID KeyValue, BOOLEAN Alertable, PLARGE_INTEGER Timeout)) \ - \ - X(NTSTATUS, NTAPI, NtWaitForKeyedEvent, \ - (HANDLE KeyedEventHandle, PVOID KeyValue, BOOLEAN Alertable, PLARGE_INTEGER Timeout)) \ - \ - X(ULONG, WINAPI, RtlNtStatusToDosError, (NTSTATUS Status)) - -#define X(return_type, attributes, name, parameters) WEPOLL_INTERNAL_EXTERN return_type(attributes *name) parameters; -NT_NTDLL_IMPORT_LIST(X) -#undef X - -#define AFD_POLL_RECEIVE 0x0001 -#define AFD_POLL_RECEIVE_EXPEDITED 0x0002 -#define AFD_POLL_SEND 0x0004 -#define AFD_POLL_DISCONNECT 0x0008 -#define AFD_POLL_ABORT 0x0010 -#define AFD_POLL_LOCAL_CLOSE 0x0020 -#define AFD_POLL_ACCEPT 0x0080 -#define AFD_POLL_CONNECT_FAIL 0x0100 - -typedef struct _AFD_POLL_HANDLE_INFO { - HANDLE Handle; - ULONG Events; - NTSTATUS Status; -} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO; - -typedef struct _AFD_POLL_INFO { - LARGE_INTEGER Timeout; - ULONG NumberOfHandles; - ULONG Exclusive; - AFD_POLL_HANDLE_INFO Handles[1]; -} AFD_POLL_INFO, *PAFD_POLL_INFO; - -WEPOLL_INTERNAL int afd_create_device_handle(HANDLE iocp_handle, HANDLE *afd_device_handle_out); - -WEPOLL_INTERNAL int afd_poll(HANDLE afd_device_handle, AFD_POLL_INFO *poll_info, IO_STATUS_BLOCK *io_status_block); -WEPOLL_INTERNAL int afd_cancel_poll(HANDLE afd_device_handle, IO_STATUS_BLOCK *io_status_block); - -#define return_map_error(value) \ - do { \ - err_map_win_error(); \ - return (value); \ - } while (0) - -#define return_set_error(value, error) \ - do { \ - err_set_win_error(error); \ - return (value); \ - } while (0) - -WEPOLL_INTERNAL void err_map_win_error(void); -WEPOLL_INTERNAL void err_set_win_error(DWORD error); -WEPOLL_INTERNAL int err_check_handle(HANDLE handle); - -#define IOCTL_AFD_POLL 0x00012024 - -static UNICODE_STRING afd__device_name = RTL_CONSTANT_STRING(L"\\Device\\Afd\\Wepoll"); - -static OBJECT_ATTRIBUTES afd__device_attributes = RTL_CONSTANT_OBJECT_ATTRIBUTES(&afd__device_name, 0); - -int afd_create_device_handle(HANDLE iocp_handle, HANDLE *afd_device_handle_out) { - HANDLE afd_device_handle; - IO_STATUS_BLOCK iosb; - NTSTATUS status; - - /* By opening \Device\Afd without specifying any extended attributes, we'll - * get a handle that lets us talk to the AFD driver, but that doesn't have an - * associated endpoint (so it's not a socket). */ - status = NtCreateFile( - &afd_device_handle, SYNCHRONIZE, &afd__device_attributes, &iosb, NULL, 0, FILE_SHARE_READ | FILE_SHARE_WRITE, - FILE_OPEN, 0, NULL, 0); - if (status != STATUS_SUCCESS) - return_set_error(-1, RtlNtStatusToDosError(status)); - - if (CreateIoCompletionPort(afd_device_handle, iocp_handle, 0, 0) == NULL) - goto error; - - if (!SetFileCompletionNotificationModes(afd_device_handle, FILE_SKIP_SET_EVENT_ON_HANDLE)) - goto error; - - *afd_device_handle_out = afd_device_handle; - return 0; - -error: - CloseHandle(afd_device_handle); - return_map_error(-1); -} - -int afd_poll(HANDLE afd_device_handle, AFD_POLL_INFO *poll_info, IO_STATUS_BLOCK *io_status_block) { - NTSTATUS status; - - /* Blocking operation is not supported. */ - assert(io_status_block != NULL); - - io_status_block->Status = STATUS_PENDING; - status = NtDeviceIoControlFile( - afd_device_handle, NULL, NULL, io_status_block, io_status_block, IOCTL_AFD_POLL, poll_info, sizeof *poll_info, - poll_info, sizeof *poll_info); - - if (status == STATUS_SUCCESS) - return 0; - else if (status == STATUS_PENDING) - return_set_error(-1, ERROR_IO_PENDING); - else - return_set_error(-1, RtlNtStatusToDosError(status)); -} - -int afd_cancel_poll(HANDLE afd_device_handle, IO_STATUS_BLOCK *io_status_block) { - NTSTATUS cancel_status; - IO_STATUS_BLOCK cancel_iosb; - - /* If the poll operation has already completed or has been cancelled earlier, - * there's nothing left for us to do. */ - if (io_status_block->Status != STATUS_PENDING) - return 0; - - cancel_status = NtCancelIoFileEx(afd_device_handle, io_status_block, &cancel_iosb); - - /* NtCancelIoFileEx() may return STATUS_NOT_FOUND if the operation completed - * just before calling NtCancelIoFileEx(). This is not an error. */ - if (cancel_status == STATUS_SUCCESS || cancel_status == STATUS_NOT_FOUND) - return 0; - else - return_set_error(-1, RtlNtStatusToDosError(cancel_status)); -} - -WEPOLL_INTERNAL int epoll_global_init(void); - -WEPOLL_INTERNAL int init(void); - -typedef struct port_state port_state_t; -typedef struct queue queue_t; -typedef struct sock_state sock_state_t; -typedef struct ts_tree_node ts_tree_node_t; - -WEPOLL_INTERNAL port_state_t *port_new(HANDLE *iocp_handle_out); -WEPOLL_INTERNAL int port_close(port_state_t *port_state); -WEPOLL_INTERNAL int port_delete(port_state_t *port_state); - -WEPOLL_INTERNAL int port_wait(port_state_t *port_state, struct epoll_event *events, int maxevents, int timeout); - -WEPOLL_INTERNAL int port_ctl(port_state_t *port_state, int op, SOCKET sock, struct epoll_event *ev); - -WEPOLL_INTERNAL int port_register_socket(port_state_t *port_state, sock_state_t *sock_state, SOCKET socket); -WEPOLL_INTERNAL void port_unregister_socket(port_state_t *port_state, sock_state_t *sock_state); -WEPOLL_INTERNAL sock_state_t *port_find_socket(port_state_t *port_state, SOCKET socket); - -WEPOLL_INTERNAL void port_request_socket_update(port_state_t *port_state, sock_state_t *sock_state); -WEPOLL_INTERNAL void port_cancel_socket_update(port_state_t *port_state, sock_state_t *sock_state); - -WEPOLL_INTERNAL void port_add_deleted_socket(port_state_t *port_state, sock_state_t *sock_state); -WEPOLL_INTERNAL void port_remove_deleted_socket(port_state_t *port_state, sock_state_t *sock_state); - -WEPOLL_INTERNAL HANDLE port_get_iocp_handle(port_state_t *port_state); -WEPOLL_INTERNAL queue_t *port_get_poll_group_queue(port_state_t *port_state); - -WEPOLL_INTERNAL port_state_t *port_state_from_handle_tree_node(ts_tree_node_t *tree_node); -WEPOLL_INTERNAL ts_tree_node_t *port_state_to_handle_tree_node(port_state_t *port_state); - -/* The reflock is a special kind of lock that normally prevents a chunk of - * memory from being freed, but does allow the chunk of memory to eventually be - * released in a coordinated fashion. - * - * Under normal operation, threads increase and decrease the reference count, - * which are wait-free operations. - * - * Exactly once during the reflock's lifecycle, a thread holding a reference to - * the lock may "destroy" the lock; this operation blocks until all other - * threads holding a reference to the lock have dereferenced it. After - * "destroy" returns, the calling thread may assume that no other threads have - * a reference to the lock. - * - * Attemmpting to lock or destroy a lock after reflock_unref_and_destroy() has - * been called is invalid and results in undefined behavior. Therefore the user - * should use another lock to guarantee that this can't happen. - */ - -typedef struct reflock { - volatile long state; /* 32-bit Interlocked APIs operate on `long` values. */ -} reflock_t; - -WEPOLL_INTERNAL int reflock_global_init(void); - -WEPOLL_INTERNAL void reflock_init(reflock_t *reflock); -WEPOLL_INTERNAL void reflock_ref(reflock_t *reflock); -WEPOLL_INTERNAL void reflock_unref(reflock_t *reflock); -WEPOLL_INTERNAL void reflock_unref_and_destroy(reflock_t *reflock); - -#include - -/* N.b.: the tree functions do not set errno or LastError when they fail. Each - * of the API functions has at most one failure mode. It is up to the caller to - * set an appropriate error code when necessary. */ - -typedef struct tree tree_t; -typedef struct tree_node tree_node_t; - -typedef struct tree { - tree_node_t *root; -} tree_t; - -typedef struct tree_node { - tree_node_t *left; - tree_node_t *right; - tree_node_t *parent; - uintptr_t key; - bool red; -} tree_node_t; - -WEPOLL_INTERNAL void tree_init(tree_t *tree); -WEPOLL_INTERNAL void tree_node_init(tree_node_t *node); - -WEPOLL_INTERNAL int tree_add(tree_t *tree, tree_node_t *node, uintptr_t key); -WEPOLL_INTERNAL void tree_del(tree_t *tree, tree_node_t *node); - -WEPOLL_INTERNAL tree_node_t *tree_find(const tree_t *tree, uintptr_t key); -WEPOLL_INTERNAL tree_node_t *tree_root(const tree_t *tree); - -typedef struct ts_tree { - tree_t tree; - SRWLOCK lock; -} ts_tree_t; - -typedef struct ts_tree_node { - tree_node_t tree_node; - reflock_t reflock; -} ts_tree_node_t; - -WEPOLL_INTERNAL void ts_tree_init(ts_tree_t *rtl); -WEPOLL_INTERNAL void ts_tree_node_init(ts_tree_node_t *node); - -WEPOLL_INTERNAL int ts_tree_add(ts_tree_t *ts_tree, ts_tree_node_t *node, uintptr_t key); - -WEPOLL_INTERNAL ts_tree_node_t *ts_tree_del_and_ref(ts_tree_t *ts_tree, uintptr_t key); -WEPOLL_INTERNAL ts_tree_node_t *ts_tree_find_and_ref(ts_tree_t *ts_tree, uintptr_t key); - -WEPOLL_INTERNAL void ts_tree_node_unref(ts_tree_node_t *node); -WEPOLL_INTERNAL void ts_tree_node_unref_and_destroy(ts_tree_node_t *node); - -static ts_tree_t epoll__handle_tree; - -int epoll_global_init(void) { - ts_tree_init(&epoll__handle_tree); - return 0; -} - -static HANDLE epoll__create(void) { - port_state_t *port_state; - HANDLE ephnd; - ts_tree_node_t *tree_node; - - if (init() < 0) - return NULL; - - port_state = port_new(&ephnd); - if (port_state == NULL) - return NULL; - - tree_node = port_state_to_handle_tree_node(port_state); - if (ts_tree_add(&epoll__handle_tree, tree_node, (uintptr_t)ephnd) < 0) { - /* This should never happen. */ - port_delete(port_state); - return_set_error(NULL, ERROR_ALREADY_EXISTS); - } - - return ephnd; -} - -HANDLE epoll_create(int size) { - if (size <= 0) - return_set_error(NULL, ERROR_INVALID_PARAMETER); - - return epoll__create(); -} - -HANDLE epoll_create1(int flags) { - if (flags != 0) - return_set_error(NULL, ERROR_INVALID_PARAMETER); - - return epoll__create(); -} - -int epoll_close(HANDLE ephnd) { - ts_tree_node_t *tree_node; - port_state_t *port_state; - - if (init() < 0) - return -1; - - tree_node = ts_tree_del_and_ref(&epoll__handle_tree, (uintptr_t)ephnd); - if (tree_node == NULL) { - err_set_win_error(ERROR_INVALID_PARAMETER); - goto err; - } - - port_state = port_state_from_handle_tree_node(tree_node); - port_close(port_state); - - ts_tree_node_unref_and_destroy(tree_node); - - return port_delete(port_state); - -err: - err_check_handle(ephnd); - return -1; -} - -int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event *ev) { - ts_tree_node_t *tree_node; - port_state_t *port_state; - int r; - - if (init() < 0) - return -1; - - tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t)ephnd); - if (tree_node == NULL) { - err_set_win_error(ERROR_INVALID_PARAMETER); - goto err; - } - - port_state = port_state_from_handle_tree_node(tree_node); - r = port_ctl(port_state, op, sock, ev); - - ts_tree_node_unref(tree_node); - - if (r < 0) - goto err; - - return 0; - -err: - /* On Linux, in the case of epoll_ctl(), EBADF takes priority over other - * errors. Wepoll mimics this behavior. */ - err_check_handle(ephnd); - err_check_handle((HANDLE)sock); - return -1; -} - -int epoll_wait(HANDLE ephnd, struct epoll_event *events, int maxevents, int timeout) { - ts_tree_node_t *tree_node; - port_state_t *port_state; - int num_events; - - if (maxevents <= 0) - return_set_error(-1, ERROR_INVALID_PARAMETER); - - if (init() < 0) - return -1; - - tree_node = ts_tree_find_and_ref(&epoll__handle_tree, (uintptr_t)ephnd); - if (tree_node == NULL) { - err_set_win_error(ERROR_INVALID_PARAMETER); - goto err; - } - - port_state = port_state_from_handle_tree_node(tree_node); - num_events = port_wait(port_state, events, maxevents, timeout); - - ts_tree_node_unref(tree_node); - - if (num_events < 0) - goto err; - - return num_events; - -err: - err_check_handle(ephnd); - return -1; -} - -#include - -#define ERR__ERRNO_MAPPINGS(X) \ - X(ERROR_ACCESS_DENIED, EACCES) \ - X(ERROR_ALREADY_EXISTS, EEXIST) \ - X(ERROR_BAD_COMMAND, EACCES) \ - X(ERROR_BAD_EXE_FORMAT, ENOEXEC) \ - X(ERROR_BAD_LENGTH, EACCES) \ - X(ERROR_BAD_NETPATH, ENOENT) \ - X(ERROR_BAD_NET_NAME, ENOENT) \ - X(ERROR_BAD_NET_RESP, ENETDOWN) \ - X(ERROR_BAD_PATHNAME, ENOENT) \ - X(ERROR_BROKEN_PIPE, EPIPE) \ - X(ERROR_CANNOT_MAKE, EACCES) \ - X(ERROR_COMMITMENT_LIMIT, ENOMEM) \ - X(ERROR_CONNECTION_ABORTED, ECONNABORTED) \ - X(ERROR_CONNECTION_ACTIVE, EISCONN) \ - X(ERROR_CONNECTION_REFUSED, ECONNREFUSED) \ - X(ERROR_CRC, EACCES) \ - X(ERROR_DIR_NOT_EMPTY, ENOTEMPTY) \ - X(ERROR_DISK_FULL, ENOSPC) \ - X(ERROR_DUP_NAME, EADDRINUSE) \ - X(ERROR_FILENAME_EXCED_RANGE, ENOENT) \ - X(ERROR_FILE_NOT_FOUND, ENOENT) \ - X(ERROR_GEN_FAILURE, EACCES) \ - X(ERROR_GRACEFUL_DISCONNECT, EPIPE) \ - X(ERROR_HOST_DOWN, EHOSTUNREACH) \ - X(ERROR_HOST_UNREACHABLE, EHOSTUNREACH) \ - X(ERROR_INSUFFICIENT_BUFFER, EFAULT) \ - X(ERROR_INVALID_ADDRESS, EADDRNOTAVAIL) \ - X(ERROR_INVALID_FUNCTION, EINVAL) \ - X(ERROR_INVALID_HANDLE, EBADF) \ - X(ERROR_INVALID_NETNAME, EADDRNOTAVAIL) \ - X(ERROR_INVALID_PARAMETER, EINVAL) \ - X(ERROR_INVALID_USER_BUFFER, EMSGSIZE) \ - X(ERROR_IO_PENDING, EINPROGRESS) \ - X(ERROR_LOCK_VIOLATION, EACCES) \ - X(ERROR_MORE_DATA, EMSGSIZE) \ - X(ERROR_NETNAME_DELETED, ECONNABORTED) \ - X(ERROR_NETWORK_ACCESS_DENIED, EACCES) \ - X(ERROR_NETWORK_BUSY, ENETDOWN) \ - X(ERROR_NETWORK_UNREACHABLE, ENETUNREACH) \ - X(ERROR_NOACCESS, EFAULT) \ - X(ERROR_NONPAGED_SYSTEM_RESOURCES, ENOMEM) \ - X(ERROR_NOT_ENOUGH_MEMORY, ENOMEM) \ - X(ERROR_NOT_ENOUGH_QUOTA, ENOMEM) \ - X(ERROR_NOT_FOUND, ENOENT) \ - X(ERROR_NOT_LOCKED, EACCES) \ - X(ERROR_NOT_READY, EACCES) \ - X(ERROR_NOT_SAME_DEVICE, EXDEV) \ - X(ERROR_NOT_SUPPORTED, ENOTSUP) \ - X(ERROR_NO_MORE_FILES, ENOENT) \ - X(ERROR_NO_SYSTEM_RESOURCES, ENOMEM) \ - X(ERROR_OPERATION_ABORTED, EINTR) \ - X(ERROR_OUT_OF_PAPER, EACCES) \ - X(ERROR_PAGED_SYSTEM_RESOURCES, ENOMEM) \ - X(ERROR_PAGEFILE_QUOTA, ENOMEM) \ - X(ERROR_PATH_NOT_FOUND, ENOENT) \ - X(ERROR_PIPE_NOT_CONNECTED, EPIPE) \ - X(ERROR_PORT_UNREACHABLE, ECONNRESET) \ - X(ERROR_PROTOCOL_UNREACHABLE, ENETUNREACH) \ - X(ERROR_REM_NOT_LIST, ECONNREFUSED) \ - X(ERROR_REQUEST_ABORTED, EINTR) \ - X(ERROR_REQ_NOT_ACCEP, EWOULDBLOCK) \ - X(ERROR_SECTOR_NOT_FOUND, EACCES) \ - X(ERROR_SEM_TIMEOUT, ETIMEDOUT) \ - X(ERROR_SHARING_VIOLATION, EACCES) \ - X(ERROR_TOO_MANY_NAMES, ENOMEM) \ - X(ERROR_TOO_MANY_OPEN_FILES, EMFILE) \ - X(ERROR_UNEXP_NET_ERR, ECONNABORTED) \ - X(ERROR_WAIT_NO_CHILDREN, ECHILD) \ - X(ERROR_WORKING_SET_QUOTA, ENOMEM) \ - X(ERROR_WRITE_PROTECT, EACCES) \ - X(ERROR_WRONG_DISK, EACCES) \ - X(WSAEACCES, EACCES) \ - X(WSAEADDRINUSE, EADDRINUSE) \ - X(WSAEADDRNOTAVAIL, EADDRNOTAVAIL) \ - X(WSAEAFNOSUPPORT, EAFNOSUPPORT) \ - X(WSAECONNABORTED, ECONNABORTED) \ - X(WSAECONNREFUSED, ECONNREFUSED) \ - X(WSAECONNRESET, ECONNRESET) \ - X(WSAEDISCON, EPIPE) \ - X(WSAEFAULT, EFAULT) \ - X(WSAEHOSTDOWN, EHOSTUNREACH) \ - X(WSAEHOSTUNREACH, EHOSTUNREACH) \ - X(WSAEINPROGRESS, EBUSY) \ - X(WSAEINTR, EINTR) \ - X(WSAEINVAL, EINVAL) \ - X(WSAEISCONN, EISCONN) \ - X(WSAEMSGSIZE, EMSGSIZE) \ - X(WSAENETDOWN, ENETDOWN) \ - X(WSAENETRESET, EHOSTUNREACH) \ - X(WSAENETUNREACH, ENETUNREACH) \ - X(WSAENOBUFS, ENOMEM) \ - X(WSAENOTCONN, ENOTCONN) \ - X(WSAENOTSOCK, ENOTSOCK) \ - X(WSAEOPNOTSUPP, EOPNOTSUPP) \ - X(WSAEPROCLIM, ENOMEM) \ - X(WSAESHUTDOWN, EPIPE) \ - X(WSAETIMEDOUT, ETIMEDOUT) \ - X(WSAEWOULDBLOCK, EWOULDBLOCK) \ - X(WSANOTINITIALISED, ENETDOWN) \ - X(WSASYSNOTREADY, ENETDOWN) \ - X(WSAVERNOTSUPPORTED, ENOSYS) - -static errno_t err__map_win_error_to_errno(DWORD error) { - switch (error) { -#define X(error_sym, errno_sym) \ - case error_sym: return errno_sym; - ERR__ERRNO_MAPPINGS(X) -#undef X - } - return EINVAL; -} - -void err_map_win_error(void) { - errno = err__map_win_error_to_errno(GetLastError()); -} - -void err_set_win_error(DWORD error) { - SetLastError(error); - errno = err__map_win_error_to_errno(error); -} - -int err_check_handle(HANDLE handle) { - DWORD flags; - - /* GetHandleInformation() succeeds when passed INVALID_HANDLE_VALUE, so check - * for this condition explicitly. */ - if (handle == INVALID_HANDLE_VALUE) - return_set_error(-1, ERROR_INVALID_HANDLE); - - if (!GetHandleInformation(handle, &flags)) - return_map_error(-1); - - return 0; -} - -#include - -#define array_count(a) (sizeof(a) / (sizeof((a)[0]))) - -#define container_of(ptr, type, member) ((type *)((uintptr_t)(ptr)-offsetof(type, member))) - -#define unused_var(v) ((void)(v)) - -/* Polyfill `inline` for older versions of msvc (up to Visual Studio 2013) */ -#if defined(_MSC_VER) && _MSC_VER < 1900 -#define inline __inline -#endif - -WEPOLL_INTERNAL int ws_global_init(void); -WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket); - -static bool init__done = false; -static INIT_ONCE init__once = INIT_ONCE_STATIC_INIT; - -static BOOL CALLBACK init__once_callback(INIT_ONCE *once, void *parameter, void **context) { - unused_var(once); - unused_var(parameter); - unused_var(context); - - /* N.b. that initialization order matters here. */ - if (ws_global_init() < 0 || nt_global_init() < 0 || reflock_global_init() < 0 || epoll_global_init() < 0) - return FALSE; - - init__done = true; - return TRUE; -} - -int init(void) { - if (!init__done && !InitOnceExecuteOnce(&init__once, init__once_callback, NULL, NULL)) - /* `InitOnceExecuteOnce()` itself is infallible, and it doesn't set any - * error code when the once-callback returns FALSE. We return -1 here to - * indicate that global initialization failed; the failing init function is - * resposible for setting `errno` and calling `SetLastError()`. */ - return -1; - - return 0; -} - -/* Set up a workaround for the following problem: - * FARPROC addr = GetProcAddress(...); - * MY_FUNC func = (MY_FUNC) addr; <-- GCC 8 warning/error. - * MY_FUNC func = (MY_FUNC) (void*) addr; <-- MSVC warning/error. - * To compile cleanly with either compiler, do casts with this "bridge" type: - * MY_FUNC func = (MY_FUNC) (nt__fn_ptr_cast_t) addr; */ -#ifdef __GNUC__ -typedef void *nt__fn_ptr_cast_t; -#else -typedef FARPROC nt__fn_ptr_cast_t; -#endif - -#define X(return_type, attributes, name, parameters) WEPOLL_INTERNAL return_type(attributes *name) parameters = NULL; -NT_NTDLL_IMPORT_LIST(X) -#undef X - -int nt_global_init(void) { - HMODULE ntdll; - FARPROC fn_ptr; - - ntdll = GetModuleHandleW(L"ntdll.dll"); - if (ntdll == NULL) - return -1; - -#define X(return_type, attributes, name, parameters) \ - fn_ptr = GetProcAddress(ntdll, #name); \ - if (fn_ptr == NULL) \ - return -1; \ - name = (return_type(attributes *) parameters)(nt__fn_ptr_cast_t)fn_ptr; - NT_NTDLL_IMPORT_LIST(X) -#undef X - - return 0; -} - -#include - -typedef struct poll_group poll_group_t; - -typedef struct queue_node queue_node_t; - -WEPOLL_INTERNAL poll_group_t *poll_group_acquire(port_state_t *port); -WEPOLL_INTERNAL void poll_group_release(poll_group_t *poll_group); - -WEPOLL_INTERNAL void poll_group_delete(poll_group_t *poll_group); - -WEPOLL_INTERNAL poll_group_t *poll_group_from_queue_node(queue_node_t *queue_node); -WEPOLL_INTERNAL HANDLE poll_group_get_afd_device_handle(poll_group_t *poll_group); - -typedef struct queue_node { - queue_node_t *prev; - queue_node_t *next; -} queue_node_t; - -typedef struct queue { - queue_node_t head; -} queue_t; - -WEPOLL_INTERNAL void queue_init(queue_t *queue); -WEPOLL_INTERNAL void queue_node_init(queue_node_t *node); - -WEPOLL_INTERNAL queue_node_t *queue_first(const queue_t *queue); -WEPOLL_INTERNAL queue_node_t *queue_last(const queue_t *queue); - -WEPOLL_INTERNAL void queue_prepend(queue_t *queue, queue_node_t *node); -WEPOLL_INTERNAL void queue_append(queue_t *queue, queue_node_t *node); -WEPOLL_INTERNAL void queue_move_to_start(queue_t *queue, queue_node_t *node); -WEPOLL_INTERNAL void queue_move_to_end(queue_t *queue, queue_node_t *node); -WEPOLL_INTERNAL void queue_remove(queue_node_t *node); - -WEPOLL_INTERNAL bool queue_is_empty(const queue_t *queue); -WEPOLL_INTERNAL bool queue_is_enqueued(const queue_node_t *node); - -#define POLL_GROUP__MAX_GROUP_SIZE 32 - -typedef struct poll_group { - port_state_t *port_state; - queue_node_t queue_node; - HANDLE afd_device_handle; - size_t group_size; -} poll_group_t; - -static poll_group_t *poll_group__new(port_state_t *port_state) { - HANDLE iocp_handle = port_get_iocp_handle(port_state); - queue_t *poll_group_queue = port_get_poll_group_queue(port_state); - - poll_group_t *poll_group = malloc(sizeof *poll_group); - if (poll_group == NULL) - return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY); - - memset(poll_group, 0, sizeof *poll_group); - - queue_node_init(&poll_group->queue_node); - poll_group->port_state = port_state; - - if (afd_create_device_handle(iocp_handle, &poll_group->afd_device_handle) < 0) { - free(poll_group); - return NULL; - } - - queue_append(poll_group_queue, &poll_group->queue_node); - - return poll_group; -} - -void poll_group_delete(poll_group_t *poll_group) { - assert(poll_group->group_size == 0); - CloseHandle(poll_group->afd_device_handle); - queue_remove(&poll_group->queue_node); - free(poll_group); -} - -poll_group_t *poll_group_from_queue_node(queue_node_t *queue_node) { - return container_of(queue_node, poll_group_t, queue_node); -} - -HANDLE poll_group_get_afd_device_handle(poll_group_t *poll_group) { - return poll_group->afd_device_handle; -} - -poll_group_t *poll_group_acquire(port_state_t *port_state) { - queue_t *poll_group_queue = port_get_poll_group_queue(port_state); - poll_group_t *poll_group = !queue_is_empty(poll_group_queue) - ? container_of(queue_last(poll_group_queue), poll_group_t, queue_node) - : NULL; - - if (poll_group == NULL || poll_group->group_size >= POLL_GROUP__MAX_GROUP_SIZE) - poll_group = poll_group__new(port_state); - if (poll_group == NULL) - return NULL; - - if (++poll_group->group_size == POLL_GROUP__MAX_GROUP_SIZE) - queue_move_to_start(poll_group_queue, &poll_group->queue_node); - - return poll_group; -} - -void poll_group_release(poll_group_t *poll_group) { - port_state_t *port_state = poll_group->port_state; - queue_t *poll_group_queue = port_get_poll_group_queue(port_state); - - poll_group->group_size--; - assert(poll_group->group_size < POLL_GROUP__MAX_GROUP_SIZE); - - queue_move_to_end(poll_group_queue, &poll_group->queue_node); - - /* Poll groups are currently only freed when the epoll port is closed. */ -} - -WEPOLL_INTERNAL sock_state_t *sock_new(port_state_t *port_state, SOCKET socket); -WEPOLL_INTERNAL void sock_delete(port_state_t *port_state, sock_state_t *sock_state); -WEPOLL_INTERNAL void sock_force_delete(port_state_t *port_state, sock_state_t *sock_state); - -WEPOLL_INTERNAL int sock_set_event(port_state_t *port_state, sock_state_t *sock_state, const struct epoll_event *ev); - -WEPOLL_INTERNAL int sock_update(port_state_t *port_state, sock_state_t *sock_state); -WEPOLL_INTERNAL int sock_feed_event(port_state_t *port_state, IO_STATUS_BLOCK *io_status_block, struct epoll_event *ev); - -WEPOLL_INTERNAL sock_state_t *sock_state_from_queue_node(queue_node_t *queue_node); -WEPOLL_INTERNAL queue_node_t *sock_state_to_queue_node(sock_state_t *sock_state); -WEPOLL_INTERNAL sock_state_t *sock_state_from_tree_node(tree_node_t *tree_node); -WEPOLL_INTERNAL tree_node_t *sock_state_to_tree_node(sock_state_t *sock_state); - -#define PORT__MAX_ON_STACK_COMPLETIONS 256 - -typedef struct port_state { - HANDLE iocp_handle; - tree_t sock_tree; - queue_t sock_update_queue; - queue_t sock_deleted_queue; - queue_t poll_group_queue; - ts_tree_node_t handle_tree_node; - CRITICAL_SECTION lock; - size_t active_poll_count; -} port_state_t; - -static inline port_state_t *port__alloc(void) { - port_state_t *port_state = malloc(sizeof *port_state); - if (port_state == NULL) - return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY); - - return port_state; -} - -static inline void port__free(port_state_t *port) { - assert(port != NULL); - free(port); -} - -static inline HANDLE port__create_iocp(void) { - HANDLE iocp_handle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); - if (iocp_handle == NULL) - return_map_error(NULL); - - return iocp_handle; -} - -port_state_t *port_new(HANDLE *iocp_handle_out) { - port_state_t *port_state; - HANDLE iocp_handle; - - port_state = port__alloc(); - if (port_state == NULL) - goto err1; - - iocp_handle = port__create_iocp(); - if (iocp_handle == NULL) - goto err2; - - memset(port_state, 0, sizeof *port_state); - - port_state->iocp_handle = iocp_handle; - tree_init(&port_state->sock_tree); - queue_init(&port_state->sock_update_queue); - queue_init(&port_state->sock_deleted_queue); - queue_init(&port_state->poll_group_queue); - ts_tree_node_init(&port_state->handle_tree_node); - InitializeCriticalSection(&port_state->lock); - - *iocp_handle_out = iocp_handle; - return port_state; - -err2: - port__free(port_state); -err1: - return NULL; -} - -static inline int port__close_iocp(port_state_t *port_state) { - HANDLE iocp_handle = port_state->iocp_handle; - port_state->iocp_handle = NULL; - - if (!CloseHandle(iocp_handle)) - return_map_error(-1); - - return 0; -} - -int port_close(port_state_t *port_state) { - int result; - - EnterCriticalSection(&port_state->lock); - result = port__close_iocp(port_state); - LeaveCriticalSection(&port_state->lock); - - return result; -} - -int port_delete(port_state_t *port_state) { - tree_node_t *tree_node; - queue_node_t *queue_node; - - /* At this point the IOCP port should have been closed. */ - assert(port_state->iocp_handle == NULL); - - while ((tree_node = tree_root(&port_state->sock_tree)) != NULL) { - sock_state_t *sock_state = sock_state_from_tree_node(tree_node); - sock_force_delete(port_state, sock_state); - } - - while ((queue_node = queue_first(&port_state->sock_deleted_queue)) != NULL) { - sock_state_t *sock_state = sock_state_from_queue_node(queue_node); - sock_force_delete(port_state, sock_state); - } - - while ((queue_node = queue_first(&port_state->poll_group_queue)) != NULL) { - poll_group_t *poll_group = poll_group_from_queue_node(queue_node); - poll_group_delete(poll_group); - } - - assert(queue_is_empty(&port_state->sock_update_queue)); - - DeleteCriticalSection(&port_state->lock); - - port__free(port_state); - - return 0; -} - -static int port__update_events(port_state_t *port_state) { - queue_t *sock_update_queue = &port_state->sock_update_queue; - - /* Walk the queue, submitting new poll requests for every socket that needs - * it. */ - while (!queue_is_empty(sock_update_queue)) { - queue_node_t *queue_node = queue_first(sock_update_queue); - sock_state_t *sock_state = sock_state_from_queue_node(queue_node); - - if (sock_update(port_state, sock_state) < 0) - return -1; - - /* sock_update() removes the socket from the update queue. */ - } - - return 0; -} - -static inline void port__update_events_if_polling(port_state_t *port_state) { - if (port_state->active_poll_count > 0) - port__update_events(port_state); -} - -static inline int port__feed_events( - port_state_t *port_state, struct epoll_event *epoll_events, OVERLAPPED_ENTRY *iocp_events, DWORD iocp_event_count) { - int epoll_event_count = 0; - DWORD i; - - for (i = 0; i < iocp_event_count; i++) { - IO_STATUS_BLOCK *io_status_block = (IO_STATUS_BLOCK *)iocp_events[i].lpOverlapped; - struct epoll_event *ev = &epoll_events[epoll_event_count]; - - epoll_event_count += sock_feed_event(port_state, io_status_block, ev); - } - - return epoll_event_count; -} - -static inline int port__poll( - port_state_t *port_state, struct epoll_event *epoll_events, OVERLAPPED_ENTRY *iocp_events, DWORD maxevents, - DWORD timeout) { - DWORD completion_count; - - if (port__update_events(port_state) < 0) - return -1; - - port_state->active_poll_count++; - - LeaveCriticalSection(&port_state->lock); - - BOOL r = GetQueuedCompletionStatusEx( - port_state->iocp_handle, iocp_events, maxevents, &completion_count, timeout, FALSE); - - EnterCriticalSection(&port_state->lock); - - port_state->active_poll_count--; - - if (!r) - return_map_error(-1); - - return port__feed_events(port_state, epoll_events, iocp_events, completion_count); -} - -int port_wait(port_state_t *port_state, struct epoll_event *events, int maxevents, int timeout) { - OVERLAPPED_ENTRY stack_iocp_events[PORT__MAX_ON_STACK_COMPLETIONS]; - OVERLAPPED_ENTRY *iocp_events; - uint64_t due = 0; - DWORD gqcs_timeout; - int result; - - /* Check whether `maxevents` is in range. */ - if (maxevents <= 0) - return_set_error(-1, ERROR_INVALID_PARAMETER); - - /* Decide whether the IOCP completion list can live on the stack, or allocate - * memory for it on the heap. */ - if ((size_t)maxevents <= array_count(stack_iocp_events)) { - iocp_events = stack_iocp_events; - } else if ((iocp_events = malloc((size_t)maxevents * sizeof *iocp_events)) == NULL) { - iocp_events = stack_iocp_events; - maxevents = array_count(stack_iocp_events); - } - - /* Compute the timeout for GetQueuedCompletionStatus, and the wait end - * time, if the user specified a timeout other than zero or infinite. */ - if (timeout > 0) { - due = GetTickCount64() + (uint64_t)timeout; - gqcs_timeout = (DWORD)timeout; - } else if (timeout == 0) { - gqcs_timeout = 0; - } else { - gqcs_timeout = INFINITE; - } - - EnterCriticalSection(&port_state->lock); - - /* Dequeue completion packets until either at least one interesting event - * has been discovered, or the timeout is reached. */ - for (;;) { - uint64_t now; - - result = port__poll(port_state, events, iocp_events, (DWORD)maxevents, gqcs_timeout); - if (result < 0 || result > 0) - break; /* Result, error, or time-out. */ - - if (timeout < 0) - continue; /* When timeout is negative, never time out. */ - - /* Update time. */ - now = GetTickCount64(); - - /* Do not allow the due time to be in the past. */ - if (now >= due) { - SetLastError(WAIT_TIMEOUT); - break; - } - - /* Recompute time-out argument for GetQueuedCompletionStatus. */ - gqcs_timeout = (DWORD)(due - now); - } - - port__update_events_if_polling(port_state); - - LeaveCriticalSection(&port_state->lock); - - if (iocp_events != stack_iocp_events) - free(iocp_events); - - if (result >= 0) - return result; - else if (GetLastError() == WAIT_TIMEOUT) - return 0; - else - return -1; -} - -static inline int port__ctl_add(port_state_t *port_state, SOCKET sock, struct epoll_event *ev) { - sock_state_t *sock_state = sock_new(port_state, sock); - if (sock_state == NULL) - return -1; - - if (sock_set_event(port_state, sock_state, ev) < 0) { - sock_delete(port_state, sock_state); - return -1; - } - - port__update_events_if_polling(port_state); - - return 0; -} - -static inline int port__ctl_mod(port_state_t *port_state, SOCKET sock, struct epoll_event *ev) { - sock_state_t *sock_state = port_find_socket(port_state, sock); - if (sock_state == NULL) - return -1; - - if (sock_set_event(port_state, sock_state, ev) < 0) - return -1; - - port__update_events_if_polling(port_state); - - return 0; -} - -static inline int port__ctl_del(port_state_t *port_state, SOCKET sock) { - sock_state_t *sock_state = port_find_socket(port_state, sock); - if (sock_state == NULL) - return -1; - - sock_delete(port_state, sock_state); - - return 0; -} - -static inline int port__ctl_op(port_state_t *port_state, int op, SOCKET sock, struct epoll_event *ev) { - switch (op) { - case EPOLL_CTL_ADD: return port__ctl_add(port_state, sock, ev); - case EPOLL_CTL_MOD: return port__ctl_mod(port_state, sock, ev); - case EPOLL_CTL_DEL: return port__ctl_del(port_state, sock); - default: return_set_error(-1, ERROR_INVALID_PARAMETER); - } -} - -int port_ctl(port_state_t *port_state, int op, SOCKET sock, struct epoll_event *ev) { - int result; - - EnterCriticalSection(&port_state->lock); - result = port__ctl_op(port_state, op, sock, ev); - LeaveCriticalSection(&port_state->lock); - - return result; -} - -int port_register_socket(port_state_t *port_state, sock_state_t *sock_state, SOCKET socket) { - if (tree_add(&port_state->sock_tree, sock_state_to_tree_node(sock_state), socket) < 0) - return_set_error(-1, ERROR_ALREADY_EXISTS); - return 0; -} - -void port_unregister_socket(port_state_t *port_state, sock_state_t *sock_state) { - tree_del(&port_state->sock_tree, sock_state_to_tree_node(sock_state)); -} - -sock_state_t *port_find_socket(port_state_t *port_state, SOCKET socket) { - tree_node_t *tree_node = tree_find(&port_state->sock_tree, socket); - if (tree_node == NULL) - return_set_error(NULL, ERROR_NOT_FOUND); - return sock_state_from_tree_node(tree_node); -} - -void port_request_socket_update(port_state_t *port_state, sock_state_t *sock_state) { - if (queue_is_enqueued(sock_state_to_queue_node(sock_state))) - return; - queue_append(&port_state->sock_update_queue, sock_state_to_queue_node(sock_state)); -} - -void port_cancel_socket_update(port_state_t *port_state, sock_state_t *sock_state) { - unused_var(port_state); - if (!queue_is_enqueued(sock_state_to_queue_node(sock_state))) - return; - queue_remove(sock_state_to_queue_node(sock_state)); -} - -void port_add_deleted_socket(port_state_t *port_state, sock_state_t *sock_state) { - if (queue_is_enqueued(sock_state_to_queue_node(sock_state))) - return; - queue_append(&port_state->sock_deleted_queue, sock_state_to_queue_node(sock_state)); -} - -void port_remove_deleted_socket(port_state_t *port_state, sock_state_t *sock_state) { - unused_var(port_state); - if (!queue_is_enqueued(sock_state_to_queue_node(sock_state))) - return; - queue_remove(sock_state_to_queue_node(sock_state)); -} - -HANDLE port_get_iocp_handle(port_state_t *port_state) { - assert(port_state->iocp_handle != NULL); - return port_state->iocp_handle; -} - -queue_t *port_get_poll_group_queue(port_state_t *port_state) { - return &port_state->poll_group_queue; -} - -port_state_t *port_state_from_handle_tree_node(ts_tree_node_t *tree_node) { - return container_of(tree_node, port_state_t, handle_tree_node); -} - -ts_tree_node_t *port_state_to_handle_tree_node(port_state_t *port_state) { - return &port_state->handle_tree_node; -} - -void queue_init(queue_t *queue) { - queue_node_init(&queue->head); -} - -void queue_node_init(queue_node_t *node) { - node->prev = node; - node->next = node; -} - -static inline void queue__detach_node(queue_node_t *node) { - node->prev->next = node->next; - node->next->prev = node->prev; -} - -queue_node_t *queue_first(const queue_t *queue) { - return !queue_is_empty(queue) ? queue->head.next : NULL; -} - -queue_node_t *queue_last(const queue_t *queue) { - return !queue_is_empty(queue) ? queue->head.prev : NULL; -} - -void queue_prepend(queue_t *queue, queue_node_t *node) { - node->next = queue->head.next; - node->prev = &queue->head; - node->next->prev = node; - queue->head.next = node; -} - -void queue_append(queue_t *queue, queue_node_t *node) { - node->next = &queue->head; - node->prev = queue->head.prev; - node->prev->next = node; - queue->head.prev = node; -} - -void queue_move_to_start(queue_t *queue, queue_node_t *node) { - queue__detach_node(node); - queue_prepend(queue, node); -} - -void queue_move_to_end(queue_t *queue, queue_node_t *node) { - queue__detach_node(node); - queue_append(queue, node); -} - -void queue_remove(queue_node_t *node) { - queue__detach_node(node); - queue_node_init(node); -} - -bool queue_is_empty(const queue_t *queue) { - return !queue_is_enqueued(&queue->head); -} - -bool queue_is_enqueued(const queue_node_t *node) { - return node->prev != node; -} - -#define REFLOCK__REF ((long)0x00000001UL) -#define REFLOCK__REF_MASK ((long)0x0fffffffUL) -#define REFLOCK__DESTROY ((long)0x10000000UL) -#define REFLOCK__DESTROY_MASK ((long)0xf0000000UL) -#define REFLOCK__POISON ((long)0x300dead0UL) - -static HANDLE reflock__keyed_event = NULL; - -int reflock_global_init(void) { - NTSTATUS status = NtCreateKeyedEvent(&reflock__keyed_event, KEYEDEVENT_ALL_ACCESS, NULL, 0); - if (status != STATUS_SUCCESS) - return_set_error(-1, RtlNtStatusToDosError(status)); - return 0; -} - -void reflock_init(reflock_t *reflock) { - reflock->state = 0; -} - -static void reflock__signal_event(void *address) { - NTSTATUS status = NtReleaseKeyedEvent(reflock__keyed_event, address, FALSE, NULL); - if (status != STATUS_SUCCESS) - abort(); -} - -static void reflock__await_event(void *address) { - NTSTATUS status = NtWaitForKeyedEvent(reflock__keyed_event, address, FALSE, NULL); - if (status != STATUS_SUCCESS) - abort(); -} - -void reflock_ref(reflock_t *reflock) { - long state = InterlockedAdd(&reflock->state, REFLOCK__REF); - - /* Verify that the counter didn't overflow and the lock isn't destroyed. */ - assert((state & REFLOCK__DESTROY_MASK) == 0); - unused_var(state); -} - -void reflock_unref(reflock_t *reflock) { - long state = InterlockedAdd(&reflock->state, -REFLOCK__REF); - - /* Verify that the lock was referenced and not already destroyed. */ - assert((state & REFLOCK__DESTROY_MASK & ~REFLOCK__DESTROY) == 0); - - if (state == REFLOCK__DESTROY) - reflock__signal_event(reflock); -} - -void reflock_unref_and_destroy(reflock_t *reflock) { - long state = InterlockedAdd(&reflock->state, REFLOCK__DESTROY - REFLOCK__REF); - long ref_count = state & REFLOCK__REF_MASK; - - /* Verify that the lock was referenced and not already destroyed. */ - assert((state & REFLOCK__DESTROY_MASK) == REFLOCK__DESTROY); - - if (ref_count != 0) - reflock__await_event(reflock); - - state = InterlockedExchange(&reflock->state, REFLOCK__POISON); - assert(state == REFLOCK__DESTROY); -} - -#define SOCK__KNOWN_EPOLL_EVENTS \ - (EPOLLIN | EPOLLPRI | EPOLLOUT | EPOLLERR | EPOLLHUP | EPOLLRDNORM | EPOLLRDBAND | EPOLLWRNORM | EPOLLWRBAND \ - | EPOLLMSG | EPOLLRDHUP) - -typedef enum sock__poll_status { SOCK__POLL_IDLE = 0, SOCK__POLL_PENDING, SOCK__POLL_CANCELLED } sock__poll_status_t; - -typedef struct sock_state { - IO_STATUS_BLOCK io_status_block; - AFD_POLL_INFO poll_info; - queue_node_t queue_node; - tree_node_t tree_node; - poll_group_t *poll_group; - SOCKET base_socket; - epoll_data_t user_data; - uint32_t user_events; - uint32_t pending_events; - sock__poll_status_t poll_status; - bool delete_pending; -} sock_state_t; - -static inline sock_state_t *sock__alloc(void) { - sock_state_t *sock_state = malloc(sizeof *sock_state); - if (sock_state == NULL) - return_set_error(NULL, ERROR_NOT_ENOUGH_MEMORY); - return sock_state; -} - -static inline void sock__free(sock_state_t *sock_state) { - assert(sock_state != NULL); - free(sock_state); -} - -static inline int sock__cancel_poll(sock_state_t *sock_state) { - assert(sock_state->poll_status == SOCK__POLL_PENDING); - - if (afd_cancel_poll(poll_group_get_afd_device_handle(sock_state->poll_group), &sock_state->io_status_block) < 0) - return -1; - - sock_state->poll_status = SOCK__POLL_CANCELLED; - sock_state->pending_events = 0; - return 0; -} - -sock_state_t *sock_new(port_state_t *port_state, SOCKET socket) { - SOCKET base_socket; - poll_group_t *poll_group; - sock_state_t *sock_state; - - if (socket == 0 || socket == INVALID_SOCKET) - return_set_error(NULL, ERROR_INVALID_HANDLE); - - base_socket = ws_get_base_socket(socket); - if (base_socket == INVALID_SOCKET) - return NULL; - - poll_group = poll_group_acquire(port_state); - if (poll_group == NULL) - return NULL; - - sock_state = sock__alloc(); - if (sock_state == NULL) - goto err1; - - memset(sock_state, 0, sizeof *sock_state); - - sock_state->base_socket = base_socket; - sock_state->poll_group = poll_group; - - tree_node_init(&sock_state->tree_node); - queue_node_init(&sock_state->queue_node); - - if (port_register_socket(port_state, sock_state, socket) < 0) - goto err2; - - return sock_state; - -err2: - sock__free(sock_state); -err1: - poll_group_release(poll_group); - - return NULL; -} - -static int sock__delete(port_state_t *port_state, sock_state_t *sock_state, bool force) { - if (!sock_state->delete_pending) { - if (sock_state->poll_status == SOCK__POLL_PENDING) - sock__cancel_poll(sock_state); - - port_cancel_socket_update(port_state, sock_state); - port_unregister_socket(port_state, sock_state); - - sock_state->delete_pending = true; - } - - /* If the poll request still needs to complete, the sock_state object can't - * be free()d yet. `sock_feed_event()` or `port_close()` will take care - * of this later. */ - if (force || sock_state->poll_status == SOCK__POLL_IDLE) { - /* Free the sock_state now. */ - port_remove_deleted_socket(port_state, sock_state); - poll_group_release(sock_state->poll_group); - sock__free(sock_state); - } else { - /* Free the socket later. */ - port_add_deleted_socket(port_state, sock_state); - } - - return 0; -} - -void sock_delete(port_state_t *port_state, sock_state_t *sock_state) { - sock__delete(port_state, sock_state, false); -} - -void sock_force_delete(port_state_t *port_state, sock_state_t *sock_state) { - sock__delete(port_state, sock_state, true); -} - -int sock_set_event(port_state_t *port_state, sock_state_t *sock_state, const struct epoll_event *ev) { - /* EPOLLERR and EPOLLHUP are always reported, even when not requested by the - * caller. However they are disabled after a event has been reported for a - * socket for which the EPOLLONESHOT flag was set. */ - uint32_t events = ev->events | EPOLLERR | EPOLLHUP; - - sock_state->user_events = events; - sock_state->user_data = ev->data; - - if ((events & SOCK__KNOWN_EPOLL_EVENTS & ~sock_state->pending_events) != 0) - port_request_socket_update(port_state, sock_state); - - return 0; -} - -static inline DWORD sock__epoll_events_to_afd_events(uint32_t epoll_events) { - /* Always monitor for AFD_POLL_LOCAL_CLOSE, which is triggered when the - * socket is closed with closesocket() or CloseHandle(). */ - DWORD afd_events = AFD_POLL_LOCAL_CLOSE; - - if (epoll_events & (EPOLLIN | EPOLLRDNORM)) - afd_events |= AFD_POLL_RECEIVE | AFD_POLL_ACCEPT; - if (epoll_events & (EPOLLPRI | EPOLLRDBAND)) - afd_events |= AFD_POLL_RECEIVE_EXPEDITED; - if (epoll_events & (EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND)) - afd_events |= AFD_POLL_SEND; - if (epoll_events & (EPOLLIN | EPOLLRDNORM | EPOLLRDHUP)) - afd_events |= AFD_POLL_DISCONNECT; - if (epoll_events & EPOLLHUP) - afd_events |= AFD_POLL_ABORT; - if (epoll_events & EPOLLERR) - afd_events |= AFD_POLL_CONNECT_FAIL; - - return afd_events; -} - -static inline uint32_t sock__afd_events_to_epoll_events(DWORD afd_events) { - uint32_t epoll_events = 0; - - if (afd_events & (AFD_POLL_RECEIVE | AFD_POLL_ACCEPT)) - epoll_events |= EPOLLIN | EPOLLRDNORM; - if (afd_events & AFD_POLL_RECEIVE_EXPEDITED) - epoll_events |= EPOLLPRI | EPOLLRDBAND; - if (afd_events & AFD_POLL_SEND) - epoll_events |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; - if (afd_events & AFD_POLL_DISCONNECT) - epoll_events |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; - if (afd_events & AFD_POLL_ABORT) - epoll_events |= EPOLLHUP; - if (afd_events & AFD_POLL_CONNECT_FAIL) - /* Linux reports all these events after connect() has failed. */ - epoll_events |= EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLRDNORM | EPOLLWRNORM | EPOLLRDHUP; - - return epoll_events; -} - -int sock_update(port_state_t *port_state, sock_state_t *sock_state) { - assert(!sock_state->delete_pending); - - if ((sock_state->poll_status == SOCK__POLL_PENDING) - && (sock_state->user_events & SOCK__KNOWN_EPOLL_EVENTS & ~sock_state->pending_events) == 0) { - /* All the events the user is interested in are already being monitored by - * the pending poll operation. It might spuriously complete because of an - * event that we're no longer interested in; when that happens we'll submit - * a new poll operation with the updated event mask. */ - - } else if (sock_state->poll_status == SOCK__POLL_PENDING) { - /* A poll operation is already pending, but it's not monitoring for all the - * events that the user is interested in. Therefore, cancel the pending - * poll operation; when we receive it's completion package, a new poll - * operation will be submitted with the correct event mask. */ - if (sock__cancel_poll(sock_state) < 0) - return -1; - - } else if (sock_state->poll_status == SOCK__POLL_CANCELLED) { - /* The poll operation has already been cancelled, we're still waiting for - * it to return. For now, there's nothing that needs to be done. */ - - } else if (sock_state->poll_status == SOCK__POLL_IDLE) { - /* No poll operation is pending; start one. */ - sock_state->poll_info.Exclusive = FALSE; - sock_state->poll_info.NumberOfHandles = 1; - sock_state->poll_info.Timeout.QuadPart = INT64_MAX; - sock_state->poll_info.Handles[0].Handle = (HANDLE)sock_state->base_socket; - sock_state->poll_info.Handles[0].Status = 0; - sock_state->poll_info.Handles[0].Events = sock__epoll_events_to_afd_events(sock_state->user_events); - - if (afd_poll( - poll_group_get_afd_device_handle(sock_state->poll_group), &sock_state->poll_info, - &sock_state->io_status_block) - < 0) { - switch (GetLastError()) { - case ERROR_IO_PENDING: - /* Overlapped poll operation in progress; this is expected. */ - break; - case ERROR_INVALID_HANDLE: - /* Socket closed; it'll be dropped from the epoll set. */ - return sock__delete(port_state, sock_state, false); - default: - /* Other errors are propagated to the caller. */ - return_map_error(-1); - } - } - - /* The poll request was successfully submitted. */ - sock_state->poll_status = SOCK__POLL_PENDING; - sock_state->pending_events = sock_state->user_events; - - } else { - /* Unreachable. */ - assert(false); - } - - port_cancel_socket_update(port_state, sock_state); - return 0; -} - -int sock_feed_event(port_state_t *port_state, IO_STATUS_BLOCK *io_status_block, struct epoll_event *ev) { - sock_state_t *sock_state = container_of(io_status_block, sock_state_t, io_status_block); - AFD_POLL_INFO *poll_info = &sock_state->poll_info; - uint32_t epoll_events = 0; - - sock_state->poll_status = SOCK__POLL_IDLE; - sock_state->pending_events = 0; - - if (sock_state->delete_pending) { - /* Socket has been deleted earlier and can now be freed. */ - return sock__delete(port_state, sock_state, false); - - } else if (io_status_block->Status == STATUS_CANCELLED) { - /* The poll request was cancelled by CancelIoEx. */ - - } else if (!NT_SUCCESS(io_status_block->Status)) { - /* The overlapped request itself failed in an unexpected way. */ - epoll_events = EPOLLERR; - - } else if (poll_info->NumberOfHandles < 1) { - /* This poll operation succeeded but didn't report any socket events. */ - - } else if (poll_info->Handles[0].Events & AFD_POLL_LOCAL_CLOSE) { - /* The poll operation reported that the socket was closed. */ - return sock__delete(port_state, sock_state, false); - - } else { - /* Events related to our socket were reported. */ - epoll_events = sock__afd_events_to_epoll_events(poll_info->Handles[0].Events); - } - - /* Requeue the socket so a new poll request will be submitted. */ - port_request_socket_update(port_state, sock_state); - - /* Filter out events that the user didn't ask for. */ - epoll_events &= sock_state->user_events; - - /* Return if there are no epoll events to report. */ - if (epoll_events == 0) - return 0; - - /* If the the socket has the EPOLLONESHOT flag set, unmonitor all events, - * even EPOLLERR and EPOLLHUP. But always keep looking for closed sockets. */ - if (sock_state->user_events & EPOLLONESHOT) - sock_state->user_events = 0; - - ev->data = sock_state->user_data; - ev->events = epoll_events; - return 1; -} - -sock_state_t *sock_state_from_queue_node(queue_node_t *queue_node) { - return container_of(queue_node, sock_state_t, queue_node); -} - -queue_node_t *sock_state_to_queue_node(sock_state_t *sock_state) { - return &sock_state->queue_node; -} - -sock_state_t *sock_state_from_tree_node(tree_node_t *tree_node) { - return container_of(tree_node, sock_state_t, tree_node); -} - -tree_node_t *sock_state_to_tree_node(sock_state_t *sock_state) { - return &sock_state->tree_node; -} - -void ts_tree_init(ts_tree_t *ts_tree) { - tree_init(&ts_tree->tree); - InitializeSRWLock(&ts_tree->lock); -} - -void ts_tree_node_init(ts_tree_node_t *node) { - tree_node_init(&node->tree_node); - reflock_init(&node->reflock); -} - -int ts_tree_add(ts_tree_t *ts_tree, ts_tree_node_t *node, uintptr_t key) { - int r; - - AcquireSRWLockExclusive(&ts_tree->lock); - r = tree_add(&ts_tree->tree, &node->tree_node, key); - ReleaseSRWLockExclusive(&ts_tree->lock); - - return r; -} - -static inline ts_tree_node_t *ts_tree__find_node(ts_tree_t *ts_tree, uintptr_t key) { - tree_node_t *tree_node = tree_find(&ts_tree->tree, key); - if (tree_node == NULL) - return NULL; - - return container_of(tree_node, ts_tree_node_t, tree_node); -} - -ts_tree_node_t *ts_tree_del_and_ref(ts_tree_t *ts_tree, uintptr_t key) { - ts_tree_node_t *ts_tree_node; - - AcquireSRWLockExclusive(&ts_tree->lock); - - ts_tree_node = ts_tree__find_node(ts_tree, key); - if (ts_tree_node != NULL) { - tree_del(&ts_tree->tree, &ts_tree_node->tree_node); - reflock_ref(&ts_tree_node->reflock); - } - - ReleaseSRWLockExclusive(&ts_tree->lock); - - return ts_tree_node; -} - -ts_tree_node_t *ts_tree_find_and_ref(ts_tree_t *ts_tree, uintptr_t key) { - ts_tree_node_t *ts_tree_node; - - AcquireSRWLockShared(&ts_tree->lock); - - ts_tree_node = ts_tree__find_node(ts_tree, key); - if (ts_tree_node != NULL) - reflock_ref(&ts_tree_node->reflock); - - ReleaseSRWLockShared(&ts_tree->lock); - - return ts_tree_node; -} - -void ts_tree_node_unref(ts_tree_node_t *node) { - reflock_unref(&node->reflock); -} - -void ts_tree_node_unref_and_destroy(ts_tree_node_t *node) { - reflock_unref_and_destroy(&node->reflock); -} - -void tree_init(tree_t *tree) { - memset(tree, 0, sizeof *tree); -} - -void tree_node_init(tree_node_t *node) { - memset(node, 0, sizeof *node); -} - -#define TREE__ROTATE(cis, trans) \ - tree_node_t *p = node; \ - tree_node_t *q = node->trans; \ - tree_node_t *parent = p->parent; \ - \ - if (parent) { \ - if (parent->left == p) \ - parent->left = q; \ - else \ - parent->right = q; \ - } else { \ - tree->root = q; \ - } \ - \ - q->parent = parent; \ - p->parent = q; \ - p->trans = q->cis; \ - if (p->trans) \ - p->trans->parent = p; \ - q->cis = p; - -static inline void tree__rotate_left(tree_t *tree, tree_node_t *node) { - TREE__ROTATE(left, right) -} - -static inline void tree__rotate_right(tree_t *tree, tree_node_t *node) { - TREE__ROTATE(right, left) -} - -#define TREE__INSERT_OR_DESCEND(side) \ - if (parent->side) { \ - parent = parent->side; \ - } else { \ - parent->side = node; \ - break; \ - } - -#define TREE__REBALANCE_AFTER_INSERT(cis, trans) \ - tree_node_t *grandparent = parent->parent; \ - tree_node_t *uncle = grandparent->trans; \ - \ - if (uncle && uncle->red) { \ - parent->red = uncle->red = false; \ - grandparent->red = true; \ - node = grandparent; \ - } else { \ - if (node == parent->trans) { \ - tree__rotate_##cis(tree, parent); \ - node = parent; \ - parent = node->parent; \ - } \ - parent->red = false; \ - grandparent->red = true; \ - tree__rotate_##trans(tree, grandparent); \ - } - -int tree_add(tree_t *tree, tree_node_t *node, uintptr_t key) { - tree_node_t *parent; - - parent = tree->root; - if (parent) { - for (;;) { - if (key < parent->key) { - TREE__INSERT_OR_DESCEND(left) - } else if (key > parent->key) { - TREE__INSERT_OR_DESCEND(right) - } else { - return -1; - } - } - } else { - tree->root = node; - } - - node->key = key; - node->left = node->right = NULL; - node->parent = parent; - node->red = true; - - for (; parent && parent->red; parent = node->parent) { - if (parent == parent->parent->left) { - TREE__REBALANCE_AFTER_INSERT(left, right) - } else { - TREE__REBALANCE_AFTER_INSERT(right, left) - } - } - tree->root->red = false; - - return 0; -} - -#define TREE__REBALANCE_AFTER_REMOVE(cis, trans) \ - tree_node_t *sibling = parent->trans; \ - \ - if (sibling->red) { \ - sibling->red = false; \ - parent->red = true; \ - tree__rotate_##cis(tree, parent); \ - sibling = parent->trans; \ - } \ - if ((sibling->left && sibling->left->red) || (sibling->right && sibling->right->red)) { \ - if (!sibling->trans || !sibling->trans->red) { \ - sibling->cis->red = false; \ - sibling->red = true; \ - tree__rotate_##trans(tree, sibling); \ - sibling = parent->trans; \ - } \ - sibling->red = parent->red; \ - parent->red = sibling->trans->red = false; \ - tree__rotate_##cis(tree, parent); \ - node = tree->root; \ - break; \ - } \ - sibling->red = true; - -void tree_del(tree_t *tree, tree_node_t *node) { - tree_node_t *parent = node->parent; - tree_node_t *left = node->left; - tree_node_t *right = node->right; - tree_node_t *next; - bool red; - - if (!left) { - next = right; - } else if (!right) { - next = left; - } else { - next = right; - while (next->left) - next = next->left; - } - - if (parent) { - if (parent->left == node) - parent->left = next; - else - parent->right = next; - } else { - tree->root = next; - } - - if (left && right) { - red = next->red; - next->red = node->red; - next->left = left; - left->parent = next; - if (next != right) { - parent = next->parent; - next->parent = node->parent; - node = next->right; - parent->left = node; - next->right = right; - right->parent = next; - } else { - next->parent = parent; - parent = next; - node = next->right; - } - } else { - red = node->red; - node = next; - } - - if (node) - node->parent = parent; - if (red) - return; - if (node && node->red) { - node->red = false; - return; - } - - do { - if (node == tree->root) - break; - if (node == parent->left) { - TREE__REBALANCE_AFTER_REMOVE(left, right) - } else { - TREE__REBALANCE_AFTER_REMOVE(right, left) - } - node = parent; - parent = parent->parent; - } while (!node->red); - - if (node) - node->red = false; -} - -tree_node_t *tree_find(const tree_t *tree, uintptr_t key) { - tree_node_t *node = tree->root; - while (node) { - if (key < node->key) - node = node->left; - else if (key > node->key) - node = node->right; - else - return node; - } - return NULL; -} - -tree_node_t *tree_root(const tree_t *tree) { - return tree->root; -} - -#ifndef SIO_BSP_HANDLE_POLL -#define SIO_BSP_HANDLE_POLL 0x4800001D -#endif - -#ifndef SIO_BASE_HANDLE -#define SIO_BASE_HANDLE 0x48000022 -#endif - -int ws_global_init(void) { - int r; - WSADATA wsa_data; - - r = WSAStartup(MAKEWORD(2, 2), &wsa_data); - if (r != 0) - return_set_error(-1, (DWORD)r); - - return 0; -} - -static inline SOCKET ws__ioctl_get_bsp_socket(SOCKET socket, DWORD ioctl) { - SOCKET bsp_socket; - DWORD bytes; - - if (WSAIoctl(socket, ioctl, NULL, 0, &bsp_socket, sizeof bsp_socket, &bytes, NULL, NULL) != SOCKET_ERROR) - return bsp_socket; - else - return INVALID_SOCKET; -} - -SOCKET ws_get_base_socket(SOCKET socket) { - SOCKET base_socket; - DWORD error; - - for (;;) { - base_socket = ws__ioctl_get_bsp_socket(socket, SIO_BASE_HANDLE); - if (base_socket != INVALID_SOCKET) - return base_socket; - - error = GetLastError(); - if (error == WSAENOTSOCK) - return_set_error(INVALID_SOCKET, error); - - /* Even though Microsoft documentation clearly states that LSPs should - * never intercept the `SIO_BASE_HANDLE` ioctl [1], Komodia based LSPs do - * so anyway, breaking it, with the apparent intention of preventing LSP - * bypass [2]. Fortunately they don't handle `SIO_BSP_HANDLE_POLL`, which - * will at least let us obtain the socket associated with the next winsock - * protocol chain entry. If this succeeds, loop around and call - * `SIO_BASE_HANDLE` again with the returned BSP socket, to make sure that - * we unwrap all layers and retrieve the actual base socket. - * [1] https://docs.microsoft.com/en-us/windows/win32/winsock/winsock-ioctls - * [2] https://www.komodia.com/newwiki/index.php?title=Komodia%27s_Redirector_bug_fixes#Version_2.2.2.6 - */ - base_socket = ws__ioctl_get_bsp_socket(socket, SIO_BSP_HANDLE_POLL); - if (base_socket != INVALID_SOCKET && base_socket != socket) - socket = base_socket; - else - return_set_error(INVALID_SOCKET, error); - } -} diff --git a/3rdpart/wepoll/wepoll.h b/3rdpart/wepoll/wepoll.h deleted file mode 100644 index 7a512e88..00000000 --- a/3rdpart/wepoll/wepoll.h +++ /dev/null @@ -1,107 +0,0 @@ -/* - * wepoll - epoll for Windows - * https://github.com/piscisaureus/wepoll - * - * Copyright 2012-2020, Bert Belder - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef WEPOLL_H_ -#define WEPOLL_H_ - -#ifndef WEPOLL_EXPORT -#define WEPOLL_EXPORT -#endif - -#include - -enum EPOLL_EVENTS { - EPOLLIN = (int)(1U << 0), - EPOLLPRI = (int)(1U << 1), - EPOLLOUT = (int)(1U << 2), - EPOLLERR = (int)(1U << 3), - EPOLLHUP = (int)(1U << 4), - EPOLLRDNORM = (int)(1U << 6), - EPOLLRDBAND = (int)(1U << 7), - EPOLLWRNORM = (int)(1U << 8), - EPOLLWRBAND = (int)(1U << 9), - EPOLLMSG = (int)(1U << 10), /* Never reported. */ - EPOLLRDHUP = (int)(1U << 13), - EPOLLONESHOT = (int)(1U << 31) -}; - -#define EPOLLIN (1U << 0) -#define EPOLLPRI (1U << 1) -#define EPOLLOUT (1U << 2) -#define EPOLLERR (1U << 3) -#define EPOLLHUP (1U << 4) -#define EPOLLRDNORM (1U << 6) -#define EPOLLRDBAND (1U << 7) -#define EPOLLWRNORM (1U << 8) -#define EPOLLWRBAND (1U << 9) -#define EPOLLMSG (1U << 10) -#define EPOLLRDHUP (1U << 13) -#define EPOLLONESHOT (1U << 31) - -#define EPOLL_CTL_ADD 1 -#define EPOLL_CTL_MOD 2 -#define EPOLL_CTL_DEL 3 - -typedef void *HANDLE; -typedef uintptr_t SOCKET; - -typedef union epoll_data { - void *ptr; - int fd; - uint32_t u32; - uint64_t u64; - SOCKET sock; /* Windows specific */ - HANDLE hnd; /* Windows specific */ -} epoll_data_t; - -struct epoll_event { - uint32_t events; /* Epoll events and flags */ - epoll_data_t data; /* User data variable */ -}; - -#ifdef __cplusplus -extern "C" { -#endif - -WEPOLL_EXPORT HANDLE epoll_create(int size); -WEPOLL_EXPORT HANDLE epoll_create1(int flags); - -WEPOLL_EXPORT int epoll_close(HANDLE ephnd); - -WEPOLL_EXPORT int epoll_ctl(HANDLE ephnd, int op, SOCKET sock, struct epoll_event *event); - -WEPOLL_EXPORT int epoll_wait(HANDLE ephnd, struct epoll_event *events, int maxevents, int timeout); - -#ifdef __cplusplus -} /* extern "C" */ -#endif - -#endif /* WEPOLL_H_ */ diff --git a/AUTHORS b/AUTHORS index fcc45452..d5a5358a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -107,4 +107,22 @@ WuPeng [huangcaichun](https://github.com/huangcaichun) [jamesZHANG500](https://github.com/jamesZHANG500) [weidelong](https://github.com/wdl1697454803) -[小强先生](https://github.com/linshangqiang) \ No newline at end of file +[小强先生](https://github.com/linshangqiang) +[李之阳](https://github.com/leo94666) +[sgzed](https://github.com/sgzed) +[gaoshan](https://github.com/foobra) +[zhang2349](https://github.com/zhang2349) +[benshi](https://github.com/BenLocal) +[autoantwort](https://github.com/autoantwort) +[u7ko4](https://github.com/u7ko4) +[WengQiang](https://github.com/Tsubaki-01) +[wEnchanters](https://github.com/wEnchanters) +[sbkyy](https://github.com/sbkyy) +[wuxingzhong](https://github.com/wuxingzhong) +[286897655](https://github.com/286897655) +[ss002012](https://github.com/ss002012) +[a839419160](https://github.com/a839419160) +[oldma3095](https://github.com/oldma3095) +[Dary](https://github.com/watersounds) +[N.z](https://github.com/neesonqk) +[yanggs](https://github.com/callinglove) \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 97b37620..7223d4d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -21,7 +21,7 @@ # SOFTWARE. # -cmake_minimum_required(VERSION 3.1.3) +cmake_minimum_required(VERSION 3.1.3...3.26) # 加载自定义模块 # Load custom modules @@ -32,6 +32,8 @@ project(ZLMediaKit LANGUAGES C CXX) # 使能 C++11 # Enable C++11 set(CMAKE_CXX_STANDARD 11) +# -fPIC +set(CMAKE_POSITION_INDEPENDENT_CODE ON) option(ENABLE_API "Enable C API SDK" ON) option(ENABLE_API_STATIC_LIB "Enable mk_api static lib" OFF) @@ -42,6 +44,7 @@ option(ENABLE_FFMPEG "Enable FFmpeg" OFF) option(ENABLE_HLS "Enable HLS" ON) option(ENABLE_JEMALLOC_STATIC "Enable static linking to the jemalloc library" OFF) option(ENABLE_JEMALLOC_DUMP "Enable jemalloc to dump malloc statistics" OFF) +option(ENABLE_TCMALLOC "Enable linking to the tcmalloc library" OFF) option(ENABLE_MEM_DEBUG "Enable Memory Debug" OFF) option(ENABLE_MP4 "Enable MP4" ON) option(ENABLE_MSVC_MT "Enable MSVC Mt/Mtd lib" ON) @@ -56,10 +59,14 @@ option(ENABLE_TESTS "Enable Tests" ON) option(ENABLE_SCTP "Enable SCTP" ON) option(ENABLE_WEBRTC "Enable WebRTC" ON) option(ENABLE_X264 "Enable x264" OFF) -option(ENABLE_WEPOLL "Enable wepoll" ON) option(ENABLE_VIDEOSTACK "Enable video stack" OFF) option(DISABLE_REPORT "Disable report to report.zlmediakit.com" OFF) option(USE_SOLUTION_FOLDERS "Enable solution dir supported" ON) +option(ENABLE_OBJCOPY "Enable use objcopy to generate debug info file" ON) +# 编译静态库 +option(BUILD_SHARED_LIBS "Build shared instead of static" OFF) +option(ENABLE_PYTHON "Enable python plugin" OFF) + ############################################################################## # 设置socket默认缓冲区大小为256k.如果设置为0则不设置socket的默认缓冲区大小,使用系统内核默认值(设置为0仅对linux有效) # Set the default buffer size of the socket to 256k. If set to 0, the default buffer size of the socket will not be set, @@ -198,7 +205,10 @@ if(UNIX) if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") set(COMPILE_OPTIONS_DEFAULT ${COMPILE_OPTIONS_DEFAULT} "-g3") else() - set(COMPILE_OPTIONS_DEFAULT ${COMPILE_OPTIONS_DEFAULT} "-g0") + find_program(OBJCOPY_FOUND objcopy) + if (OBJCOPY_FOUND AND ENABLE_OBJCOPY) + set(COMPILE_OPTIONS_DEFAULT ${COMPILE_OPTIONS_DEFAULT} "-g3") + endif() endif() elseif(WIN32) if (MSVC) @@ -208,8 +218,8 @@ elseif(WIN32) # warning C4530: C++ exception handler used, but unwind semantics are not enabled. "/EHsc") # disable Windows logo - list(APPEND COMPILE_OPTIONS_DEFAULT "/nologo") - list(APPEND CMAKE_STATIC_LINKER_FLAGS "/nologo") + string(REPLACE "/nologo" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + set(CMAKE_STATIC_LINKER_FLAGS "") endif() endif() @@ -248,8 +258,8 @@ endif() # Multiple modules depend on ffmpeg related libraries, unified search if(ENABLE_FFMPEG) find_package(PkgConfig QUIET) - # 查找 ffmpeg/libutil 是否安装 - # find ffmpeg/libutil installed + # 查找 ffmpeg/libavutil 是否安装 + # find ffmpeg/libavutil installed if(PKG_CONFIG_FOUND) pkg_check_modules(AVUTIL QUIET IMPORTED_TARGET libavutil) if(AVUTIL_FOUND) @@ -288,8 +298,19 @@ if(ENABLE_FFMPEG) endif() endif() - # 查找 ffmpeg/libutil 是否安装 - # find ffmpeg/libutil installed + # 查找 ffmpeg/libavfilter 是否安装 + # find ffmpeg/libavfilter installed + if(PKG_CONFIG_FOUND) + pkg_check_modules(AVFILTER QUIET IMPORTED_TARGET libavfilter) + if(AVFILTER_FOUND) + update_cached_list(MK_LINK_LIBRARIES PkgConfig::AVFILTER) + message(STATUS "found library: ${AVFILTER_LIBRARIES}") + endif() + endif() + + + # 查找 ffmpeg/libavutil 是否安装 + # find ffmpeg/libavutil installed if(NOT AVUTIL_FOUND) find_package(AVUTIL QUIET) if(AVUTIL_FOUND) @@ -332,7 +353,16 @@ if(ENABLE_FFMPEG) endif() endif() - if(AVUTIL_FOUND AND AVCODEC_FOUND AND SWSCALE_FOUND AND SWRESAMPLE_FOUND) + if(NOT AVFILTER_FOUND) + find_package(AVFILTER QUIET) + if(AVFILTER_FOUND) + include_directories(SYSTEM ${AVFILTER_INCLUDE_DIR}) + update_cached_list(MK_LINK_LIBRARIES ${AVFILTER_LIBRARIES}) + message(STATUS "found library: ${AVFILTER_LIBRARIES}") + endif() + endif() + + if(AVUTIL_FOUND AND AVCODEC_FOUND AND SWSCALE_FOUND AND SWRESAMPLE_FOUND AND AVFILTER_FOUND) update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_FFMPEG) update_cached_list(MK_LINK_LIBRARIES ${CMAKE_DL_LIBS}) else() @@ -393,6 +423,19 @@ if(JEMALLOC_FOUND) endif () endif() +# 查找 tcmalloc 是否安装 +# find tcmalloc installed +if(ENABLE_TCMALLOC) + find_package(TCMALLOC QUIET) + if(TCMALLOC_FOUND) + message(STATUS "Link with tcmalloc library: ${TCMALLOC_LIBRARIES}") + update_cached_list(MK_LINK_LIBRARIES ${TCMALLOC_LIBRARIES}) + else() + set(ENABLE_TCMALLOC OFF) + message(WARNING "tcmalloc 相关功能未找到") + endif() +endif() + # 查找 openssl 是否安装 # find openssl installed find_package(OpenSSL QUIET) @@ -467,6 +510,17 @@ if(ENABLE_SRT) update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_SRT) endif() +if(ENABLE_WEBRTC) + # 查找 srtp 是否安装 + find_package(SRTP QUIET) + if(SRTP_FOUND AND ENABLE_OPENSSL) + message(STATUS "found library: ${SRTP_LIBRARIES}, ENABLE_WEBRTC defined") + update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_WEBRTC) + else() + set(ENABLE_WEBRTC OFF) + message(WARNING "srtp 未找到, WebRTC 相关功能打开失败") + endif() +endif() # ---------------------------------------------------------------------------- # Solution folders: # ---------------------------------------------------------------------------- @@ -551,6 +605,9 @@ file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/www" DESTINATION ${EXECUTABLE_OUTPUT_PATH file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/conf/config.ini" DESTINATION ${EXECUTABLE_OUTPUT_PATH}) file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/default.pem" DESTINATION ${EXECUTABLE_OUTPUT_PATH}) +if (ENABLE_FFMPEG) + file(COPY "${CMAKE_CURRENT_SOURCE_DIR}/DejaVuSans.ttf" DESTINATION ${EXECUTABLE_OUTPUT_PATH}) +endif () # 拷贝VideoStack 无视频流时默认填充的背景图片 # Copy the default background image used by VideoStack when there is no video stream if (ENABLE_VIDEOSTACK AND ENABLE_FFMPEG AND ENABLE_X264) diff --git a/DejaVuSans.ttf b/DejaVuSans.ttf new file mode 100644 index 00000000..2fbbe69e Binary files /dev/null and b/DejaVuSans.ttf differ diff --git a/README.md b/README.md index 05529d21..b77f1725 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ - [谁在使用zlmediakit?](https://github.com/ZLMediaKit/ZLMediaKit/issues/511) - 全面支持ipv6网络 - 支持多轨道模式(一个流中多个视频/音频) -- 全协议支持H264/H265/AAC/G711/OPUS/MP3,部分支持VP8/VP9/AV1/JPEG/MP3/H266/ADPCM/SVAC/G722/G723/G729 +- 全协议支持H264/H265/AAC/G711/OPUS/MP3/VP8/VP9/AV1,部分支持JPEG/H266/ADPCM/SVAC/G722/G723/G729/MP2 ## 项目定位 @@ -47,7 +47,7 @@ ## 功能清单 ### 功能一览 -功能一览 +功能预览 - RTSP[S] - RTSP[S] 服务器,支持RTMP/MP4/HLS转RTSP[S],支持亚马逊echo show这样的设备 @@ -57,7 +57,7 @@ - 服务器/客户端完整支持Basic/Digest方式的登录鉴权,全异步可配置化的鉴权接口 - 支持H265编码 - 服务器支持RTSP推流(包括`rtp over udp` `rtp over tcp`方式) - - 支持H264/H265/AAC/G711/OPUS/MJPEG/MP3编码,其他编码能转发但不能转协议 + - 支持H264/H265/AAC/G711/OPUS/MJPEG/MP3/VP8/VP9/AV1/MP2编码,其他编码能转发但不能转协议 - RTMP[S] - RTMP[S] 播放服务器,支持RTSP/MP4/HLS转RTMP @@ -70,25 +70,25 @@ - 支持H264/H265/AAC/G711/OPUS/MP3编码,其他编码能转发但不能转协议 - 支持[RTMP-H265](https://github.com/ksvc/FFmpeg/wiki) - 支持[RTMP-OPUS](https://github.com/ZLMediaKit/ZLMediaKit/wiki/RTMP%E5%AF%B9H265%E5%92%8COPUS%E7%9A%84%E6%94%AF%E6%8C%81) - - 支持[enhanced-rtmp(H265)](https://github.com/veovera/enhanced-rtmp) + - 支持[enhanced-rtmp(H265/VP8/VP9/AV1/OPUS)](https://github.com/veovera/enhanced-rtmp) - HLS - 支持HLS文件(mpegts/fmp4)生成,自带HTTP文件服务器 - 通过cookie追踪技术,可以模拟HLS播放为长连接,可以实现HLS按需拉流、播放统计等业务 - 支持HLS播发器,支持拉流HLS转rtsp/rtmp/mp4 - - 支持H264/H265/AAC/G711/OPUS/MP3编码 + - 支持H264/H265/AAC/G711/OPUS/MP3/VP8/VP9/AV1/MP2编码 - 支持多轨道模式 - TS - 支持http[s]-ts直播 - 支持ws[s]-ts直播 - - 支持H264/H265/AAC/G711/OPUS/MP3编码 + - 支持H264/H265/AAC/G711/OPUS/MP3/VP8/VP9/AV1/MP2编码 - 支持多轨道模式 - fMP4 - 支持http[s]-fmp4直播 - 支持ws[s]-fmp4直播 - - 支持H264/H265/AAC/G711/OPUS/MJPEG/MP3编码 + - 支持H264/H265/AAC/G711/OPUS/MJPEG/MP3/VP8/VP9/AV1/MP2编码 - 支持多轨道模式 - HTTP[S]与WebSocket @@ -103,7 +103,7 @@ - GB28181与RTP推流 - 支持UDP/TCP RTP(PS/TS/ES)推流服务器,可以转换成RTSP/RTMP/HLS等协议 - 支持RTSP/RTMP/HLS等协议转rtp推流客户端,支持TCP/UDP模式,提供相应restful api,支持主动被动方式 - - 支持H264/H265/AAC/G711/OPUS/MP3编码 + - 支持H264/H265/AAC/G711/OPUS/MP3/VP8/VP9/AV1编码 - 支持es/ps/ts/ehome rtp推流 - 支持es/ps rtp转推 - 支持GB28181主动拉流模式 @@ -113,7 +113,7 @@ - MP4点播与录制 - 支持录制为FLV/HLS/MP4 - RTSP/RTMP/HTTP-FLV/WS-FLV支持MP4文件点播,支持seek - - 支持H264/H265/AAC/G711/OPUS/MP3编码 + - 支持H264/H265/AAC/G711/OPUS/MP3/VP8/VP9/AV1编码 - 支持多轨道模式 - WebRTC @@ -131,11 +131,13 @@ - 支持webrtc over tcp模式 - 优秀的nack、jitter buffer算法, 抗丢包能力卓越 - 支持whip/whep协议 + - 支持编码格式与rtsp协议一致 + - [支持ice-full,支持作为webrtc客户端拉流、推流以及p2p模式](./webrtc/USAGE.md) + - [SRT支持](./srt/srt.md) - 其他 - 支持丰富的restful api以及web hook事件 - - 支持简单的telnet调试 - - 支持配置文件热加载 + - 支持配置文件、ssl证书热加载 - 支持流量统计、推拉流鉴权等事件 - 支持虚拟主机,可以隔离不同域名 - 支持按需拉流,无人观看自动关断拉流 @@ -146,7 +148,48 @@ - 支持按需解复用、转协议,当有人观看时才开启转协议,降低cpu占用率 - 支持溯源模式的集群部署,溯源方式支持rtsp/rtmp/hls/http-ts, 边沿站支持hls, 源站支持多个(采用round robin方式溯源) - rtsp/rtmp/webrtc推流异常断开后,可以在超时时间内重连推流,播放器无感知 + +## 闭源专业版 +在最新开源代码的基础,新增以下[闭源专业版](https://github.com/xia-chu/zlmediakit-pro) +- 音视频转码功能 + - 1、音视频间任意转码(包括h265/h264/opus/g711/aac/g722/g722.1/mp3/svac/vp8/vp9/av1等。 + - 2、基于配置文件的转码,支持设置比特率,codec类型等参数。 + - 3、基于http api的动态增减转码,支持设置比特率,分辨率倍数,codec类型、滤镜等参数。 + - 4、支持硬件、软件自适应转码。 + - 5、支持按需转码,有人观看才转码,支持透明转码模式,业务无需感知转码的存在,业务代码无需做任何调整。 + - 6、支持负载过高时,转码主动降低帧率且不花屏。 + - 7、支持滤镜,支持添加osd文本以及logo角标等能力。 + - 8、支持全GPU硬件编解码与滤镜,防止显存与内存频繁拷贝。 + +- JT1078部标版本 + - 1、支持接收jt1078推流转其他协议;自适应音视频共享seq和单独seq模式。 + - 2、支持jt1078级联,支持jt1078对讲。 + - 3、jt1078相关接口、端口和用法与GB28181用法一致,保持兼容。 + - 4、支持h264/h265/g711/aac/mp3/g721/g722/g723/g729/g726/adpcm等编码。 + +- IPTV版本 + - 1、支持rtsp-ts/hls/http-ts/rtp组播/udp组播拉流转协议,支持ts透传模式,无需解复用转rtsp-ts/hls/http-ts/srt协议。 + - 2、支持接收rtsp-ts/srt/rtp-ts推流,支持ts透传模式,无需解复用转rtsp-ts/hls/http-ts/srt协议。 + - 3、上述功能同时支持解复用ts为es流再转rtsp/rtmp/flv/http-ts/hls/hls-fmp4/mp4/fmp4/webrtc等协议。 + +- S3云存储 + - 支持s3/minio云存储内存流直接写入,解决录像文件io系统瓶颈问题 + - 支持直接通过zlmediakit的http服务下载和点播云存储文件。 + - 支持遍历云存储文件并生成http菜单网页。 + +- WebRTC集群 + - 支持rtc流量代理,解决k8s部署zlmediakit webrtc服务时,http信令交互与rtc流量打不到同一个pod实例的问题。 + +- AI推理 + - 支持yolo推理插件,支持人员、车辆等目标AI识别,支持目标跟踪,支持多边形布防,支持ocr,支持c++/python插件快速混合开发。 + - 支持tensorRT 全cuda加速推理。 + - 支持onnxruntime(cpu/gpu) 推理。 + - 支持ascend cann加速推理。 + - python插件支持调用c++接口操作流媒体与绘制当前视频画面。 +- WebRTC mcu语音聊天室 + - 支持mcu多人语音聊天室,混音前支持背景噪声消除,静音不参与混音,解决超大规模多人语音聊天室sfu方案不可用的问题。 + - 支持100人语音连麦,上千人旁听级会议。 ## 编译以及测试 **编译前务必仔细参考wiki:[快速开始](https://github.com/ZLMediaKit/ZLMediaKit/wiki/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B)操作!!!** @@ -191,17 +234,22 @@ bash build_docker_images.sh - [jessibuca](https://github.com/langhuihui/jessibuca) 基于wasm支持H265的播放器 - [wsPlayer](https://github.com/v354412101/wsPlayer) 基于MSE的websocket-fmp4播放器 - [BXC_gb28181Player](https://github.com/any12345com/BXC_gb28181Player) C++开发的支持国标GB28181协议的视频流播放器 - - [RTCPlayer](https://github.com/leo94666/RTCPlayer) 一个基于Android客户端的的RTC播放器 + - [WebRTC-Vue-Demo](https://github.com/Heartbreaker16/ZLMediaKit-WebRTC-Vue-Demo) zlmediakit webrtc播放器vue版本 - WEB管理网站 - [zlm_webassist](https://github.com/1002victor/zlm_webassist) 本项目配套的前后端分离web管理项目 - [AKStreamNVR](https://github.com/langmansh/AKStreamNVR) 前后端分离web项目,支持webrtc播放 + - [StreamUI](https://github.com/lmk123568/StreamUI) 一个极简、轻便的视频流媒体管理平台 + - [PyMKUI](https://github.com/ZLMediaKit/pymkui) ZLMediaKit官方推出的管理平台网站 - SDK - [spring-boot-starter](https://github.com/lunasaw/zlm-spring-boot-starter) 本项目hook和rest接口starter - [java sdk](https://github.com/lidaofu-hub/j_zlm_sdk) 本项目c sdk完整java包装库 - [c# sdk](https://github.com/malegend/ZLMediaKit.Autogen) 本项目c sdk完整c#包装库 - [metaRTC](https://github.com/metartc/metaRTC) 全国产纯c webrtc sdk + +- 监控与运维 + - [ZLMediaKit_exporter](https://github.com/guohuachan/ZLMediaKit_exporter) 一个用于采集 ZLMediaKit 核心指标的 Prometheus Exporter,搭配 Grafana 即可快速构建实时监控面板 - 其他项目(已停止更新) - [NodeJS实现的GB28181平台](https://gitee.com/hfwudao/GB28181_Node_Http) @@ -382,6 +430,9 @@ bash build_docker_images.sh [ss002012](https://github.com/ss002012) [a839419160](https://github.com/a839419160) [oldma3095](https://github.com/oldma3095) +[Dary](https://github.com/watersounds) +[N.z](https://github.com/neesonqk) +[yanggs](https://github.com/callinglove) 同时感谢JetBrains对开源项目的支持,本项目使用CLion开发与调试: diff --git a/README_en.md b/README_en.md index 1065e5b2..d9b68db4 100644 --- a/README_en.md +++ b/README_en.md @@ -45,7 +45,7 @@ ## Feature List ### Overview of Features -Overview of Features +Overview of Features - RTSP[S] - RTSP[S] server, supports RTMP/MP4/HLS to RTSP[S] conversion, supports devices such as Amazon Echo Show @@ -124,6 +124,8 @@ - Supports WebRTC over TCP mode - Excellent NACK and jitter buffer algorithms with outstanding packet loss resistance - Supports WHIP/WHEP protocols + - [Supports ice-full, works as a WebRTC client for pulling streams, pushing streams, and P2P mode](./webrtc/USAGE.md) + - [SRT support](./srt/srt.md) - Others - Supports rich RESTful APIs and webhook events @@ -139,7 +141,36 @@ - Supports on-demand demultiplexing and protocol conversion, reducing CPU usage by only enabling it when someone is watching - Supports cluster deployment in traceable mode, with RTSP/RTMP/HLS/HTTP-TS support for traceable mode and HLS support for edge stations and multiple sources for source stations (using round-robin tracing) - Can reconnect to streaming after abnormal disconnection in RTSP/RTMP/WebRTC pushing within a timeout period, with no impact on the player. - + +## Closed-Source Professional Edition +Based on the latest open-source code, the following closed-source professional editions have been added. For details, please contact: 1213642868@qq.com + +- Transcoding Version + - Supports arbitrary audio and video transcoding, including H.265/H.264/Opus/G.711/AAC/G.722/G.722.1/MP3/SVAC, etc. + - Configuration file-based transcoding, allowing customization of bitrate, codec type, and other parameters. + - Dynamic transcoding management via HTTP API, supporting settings for bitrate, resolution scaling, codec type, filters, etc. + - Supports adaptive hardware and software transcoding. + - Supports on-demand transcoding, only transcoding when a viewer is present. It also supports transparent transcoding mode, requiring no modifications to business logic. + - Supports automatic frame rate reduction under high load conditions to prevent video artifacts. + - Supports filters, including OSD text overlay and logo watermarking. + - Supports full GPU hardware encoding/decoding and filtering, minimizing frequent memory transfers between VRAM and RAM. + - Supports full GPU (CUDA) inference plugins, enabling AI-based object detection for people, vehicles, and other targets. + +- JT1078 Version + - Supports JT1078 stream ingestion and protocol conversion, with adaptive audio-video shared sequence and individual sequence modes. + - Adds JT1078 cascading support and JT1078 intercom support. + - JT1078 APIs and usage remain consistent with GB28181, ensuring compatibility. + - Supports H.264/H.265/G.711/AAC/MP3/G.721/G.722/G.723/G.729/G.726/ADPCM encoding. + +- IPTV Version + - Supports RTSP-TS/HLS/HTTP-TS/RTP multicast/UDP multicast stream ingestion and protocol conversion. Supports TS passthrough mode, eliminating the need for demuxing when converting to RTSP-TS/HLS/HTTP-TS/SRT. + - Supports RTSP-TS/SRT stream ingestion and TS passthrough mode, avoiding the need for demuxing when converting to RTSP-TS/HLS/HTTP-TS/SRT. + - All the above features also support demuxing TS into ES streams and converting them to RTSP/RTMP/FLV/HTTP-TS/HLS/HLS-FMP4/MP4/FMP4/WebRTC. + +- VP9/AV1 Version + Fully supports AV1/VP9 encoding, with RTMP/RTSP/TS/PS/HLS/MP4/FMP4 protocol compatibility for AV1/VP9. + + ## System Requirements - Compiler with c++11 support, such as GCC 4.8+, Clang 3.3+, or VC2015+. @@ -375,6 +406,8 @@ bash build_docker_images.sh - [GB28181 player implemented in C++](https://github.com/any12345com/BXC_gb28181Player) - [Android RTCPlayer](https://github.com/leo94666/RTCPlayer) +- Monitor + - [Prometheus Exporter for ZLMediaKit](https://github.com/guohuachan/ZLMediaKit_exporter) ## License @@ -542,6 +575,9 @@ Thanks to all those who have supported this project in various ways, including b [ss002012](https://github.com/ss002012) [a839419160](https://github.com/a839419160) [oldma3095](https://github.com/oldma3095) +[Dary](https://github.com/watersounds) +[N.z](https://github.com/neesonqk) +[yanggs](https://github.com/callinglove) Also thank to JetBrains for their support for open source project, we developed and debugged zlmediakit with CLion: diff --git a/api/CMakeLists.txt b/api/CMakeLists.txt index 102a03a6..ca622a6d 100644 --- a/api/CMakeLists.txt +++ b/api/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -77,6 +77,31 @@ install(TARGETS mk_api LIBRARY DESTINATION ${INSTALL_PATH_LIB} RUNTIME DESTINATION ${INSTALL_PATH_RUNTIME}) +if(MSVC) + set(RESOURCE_FILE "${CMAKE_SOURCE_DIR}/resource.rc") + set_source_files_properties(${RESOURCE_FILE} PROPERTIES LANGUAGE RC) + target_sources(mk_api PRIVATE ${RESOURCE_FILE}) +endif() + +#relase 类型时额外输出debug调试信息 +string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER) +if(UNIX AND ENABLE_OBJCOPY) + if("${CMAKE_BUILD_TYPE_LOWER}" STREQUAL "release") + find_program(OBJCOPY_FOUND objcopy) + if (OBJCOPY_FOUND) + add_custom_command(TARGET mk_api + POST_BUILD + COMMAND objcopy --only-keep-debug ${EXECUTABLE_OUTPUT_PATH}/libmk_api.so ${EXECUTABLE_OUTPUT_PATH}/libmk_api.so.debug + COMMAND objcopy --strip-all ${EXECUTABLE_OUTPUT_PATH}/libmk_api.so + COMMAND objcopy --add-gnu-debuglink=${EXECUTABLE_OUTPUT_PATH}/libmk_api.so.debug ${EXECUTABLE_OUTPUT_PATH}/libmk_api.so + ) + install(FILES ${EXECUTABLE_OUTPUT_PATH}/libmk_api.so.debug DESTINATION ${INSTALL_PATH_RUNTIME}) + else() + message(STATUS "not found objcopy, generate libmk_api.so.debug skip") + endif() + endif() +endif() + # IOS 跳过测试代码 if(IOS) return() diff --git a/api/include/mk_common.h b/api/include/mk_common.h index 5d01e8e0..78233ba3 100755 --- a/api/include/mk_common.h +++ b/api/include/mk_common.h @@ -259,31 +259,24 @@ API_EXPORT uint16_t API_CALL mk_rtp_server_start(uint16_t port); */ API_EXPORT uint16_t API_CALL mk_rtc_server_start(uint16_t port); -// 获取webrtc answer sdp回调函数 [AUTO-TRANSLATED:10c93fa9] -// Get webrtc answer sdp callback function -typedef void(API_CALL *on_mk_webrtc_get_answer_sdp)(void *user_data, const char *answer, const char *err); /** - * webrtc交换sdp,根据offer sdp生成answer sdp - * @param user_data 回调用户指针 - * @param cb 回调函数 - * @param type webrtc插件类型,支持echo,play,push - * @param offer webrtc offer sdp - * @param url rtc url, 例如 rtc://__defaultVhost/app/stream?key1=val1&key2=val2 - * webrtc exchange sdp, generate answer sdp based on offer sdp - * @param user_data Callback user pointer - * @param cb Callback function - * @param type webrtc plugin type, supports echo, play, push - * @param offer webrtc offer sdp - * @param url rtc url, for example rtc://__defaultVhost/app/stream?key1=val1&key2=val2 - - * [AUTO-TRANSLATED:ea79659b] + * 创建websocket[s]信令服务器 + * @param port websocket监听端口 + * @param ssl 是否为ssl类型服务器 + * @return 0:失败,非0:端口号 + * */ -API_EXPORT void API_CALL mk_webrtc_get_answer_sdp(void *user_data, on_mk_webrtc_get_answer_sdp cb, const char *type, - const char *offer, const char *url); +API_EXPORT uint16_t API_CALL mk_signaling_server_start(uint16_t port, int ssl); + +/** + * 创建webrtc-ice[s]服务器 + * @param port websocket监听端口 + * @return 0:失败,非0:端口号 + * + */ +API_EXPORT uint16_t API_CALL mk_ice_server_start(uint16_t port); -API_EXPORT void API_CALL mk_webrtc_get_answer_sdp2(void *user_data, on_user_data_free user_data_free, on_mk_webrtc_get_answer_sdp cb, const char *type, - const char *offer, const char *url); /** * 创建srt服务器 diff --git a/api/include/mk_events_objects.h b/api/include/mk_events_objects.h index d4f2acfd..7179c728 100644 --- a/api/include/mk_events_objects.h +++ b/api/include/mk_events_objects.h @@ -193,6 +193,8 @@ API_EXPORT uint64_t API_CALL mk_media_source_get_alive_second(const mk_media_sou API_EXPORT int API_CALL mk_media_source_close(const mk_media_source ctx,int force); //MediaSource::seekTo() API_EXPORT int API_CALL mk_media_source_seek_to(const mk_media_source ctx,uint32_t stamp); +// MediaSource::setSpeed() +API_EXPORT void API_CALL mk_media_source_set_speed(const mk_media_source ctx, float speed); /** * rtp推流成功与否的回调(第一次成功后,后面将一直重试) diff --git a/api/include/mk_frame.h b/api/include/mk_frame.h index 68d06ce2..d5da969c 100644 --- a/api/include/mk_frame.h +++ b/api/include/mk_frame.h @@ -343,6 +343,40 @@ API_EXPORT void API_CALL mk_mpeg_muxer_init_complete(mk_mpeg_muxer ctx); */ API_EXPORT int API_CALL mk_mpeg_muxer_input_frame(mk_mpeg_muxer ctx, mk_frame frame); +////////////////////////////////////////////////////////////////////// +#if defined(ENABLE_RTPPROXY) + +typedef struct mk_ps_decoder_t *mk_ps_decoder; + +typedef void (API_CALL *on_mk_ps_decoder_stream)(void *user_data, int stream, int codecid, const void *ext, size_t ext_len, int finish); +typedef void(API_CALL *on_mk_ps_decoder_frame)(void *user_data, int stream, int codecid, int flags, int64_t pts, int64_t dts, const void *data, size_t bytes); + +/** + * 创建一个ps解析器 + * @param scb stream 回调; 可选, 如果明确知道数据类型也许不需要此回调创建track? + * @param dcb 数据回调;必填 + * @param user_data 用户自定义数据 + * @return + */ +API_EXPORT mk_ps_decoder API_CALL mk_ps_decoder_create(on_mk_ps_decoder_stream scb, on_mk_ps_decoder_frame dcb, void * user_data); + +/** + * 释放ps解析器 + * @param ctx + */ +API_EXPORT void API_CALL mk_ps_decoder_release(mk_ps_decoder ctx); + +/** + * 输入ps数据 + * @param ctx ps解析器指针 + * @param data ps数据指针 + * @param bytes 数据长度 + */ +API_EXPORT void API_CALL mk_ps_decoder_input(mk_ps_decoder ctx, const char * data, size_t bytes); + + +# endif + #ifdef __cplusplus } #endif diff --git a/api/include/mk_mediakit.h b/api/include/mk_mediakit.h index b7682195..2fba3812 100755 --- a/api/include/mk_mediakit.h +++ b/api/include/mk_mediakit.h @@ -27,5 +27,6 @@ #include "mk_frame.h" #include "mk_track.h" #include "mk_transcode.h" +#include "mk_webrtc.h" #endif /* MK_API_H_ */ diff --git a/api/include/mk_recorder.h b/api/include/mk_recorder.h index e70213e8..9228c923 100644 --- a/api/include/mk_recorder.h +++ b/api/include/mk_recorder.h @@ -125,6 +125,21 @@ API_EXPORT int API_CALL mk_recorder_start(int type, const char *vhost, const cha */ API_EXPORT int API_CALL mk_recorder_stop(int type, const char *vhost, const char *app, const char *stream); + + +/** + * 开始事件视频录制 + * @param vhost 虚拟主机 + * @param app 应用名 + * @param stream 流id + * @param path 录像文件保存相对路径,包括名称 + * @param back_ms 回溯录制时长 + * @param forward_ms 后续录制时长 + * @return 1:成功,0:失败 + * */ +API_EXPORT int API_CALL mk_recorder_start_task(const char *vhost, const char *app, const char *stream, const char *path, uint32_t back_ms, uint32_t forward_ms); + + /** * 加载mp4列表 * @param vhost 虚拟主机 diff --git a/api/include/mk_rtp_server.h b/api/include/mk_rtp_server.h index 1b837bec..b85e9831 100644 --- a/api/include/mk_rtp_server.h +++ b/api/include/mk_rtp_server.h @@ -21,6 +21,7 @@ typedef struct mk_rtp_server_t *mk_rtp_server; * @param port 监听端口,0则为随机 * @param tcp_mode tcp模式(0: 不监听端口 1: 监听端口 2: 主动连接到服务端) * @param stream_id 该端口绑定的流id + * @param multiple 多路复用RTP服务器 1: 开启 0: 不开启 * @return * Create GB28181 RTP server * @param port Listening port, 0 for random @@ -32,6 +33,7 @@ typedef struct mk_rtp_server_t *mk_rtp_server; */ API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create(uint16_t port, int tcp_mode, const char *stream_id); API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create2(uint16_t port, int tcp_mode, const char *vhost, const char *app, const char *stream_id); +API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create3(uint16_t port, int tcp_mode, const char *vhost, const char *app, const char *stream_id, int multiplex); /** * TCP 主动模式时连接到服务器是否成功的回调 @@ -110,6 +112,53 @@ typedef void(API_CALL *on_mk_rtp_server_detach)(void *user_data); API_EXPORT void API_CALL mk_rtp_server_set_on_detach(mk_rtp_server ctx, on_mk_rtp_server_detach cb, void *user_data); API_EXPORT void API_CALL mk_rtp_server_set_on_detach2(mk_rtp_server ctx, on_mk_rtp_server_detach cb, void *user_data, on_user_data_free user_data_free); +/** + * 更新RTP服务器过滤SSRC + * @param ctx 服务器对象 + * @param ssrc 十进制ssrc + * + */ +API_EXPORT void API_CALL mk_rtp_server_update_ssrc(mk_rtp_server ctx, uint32_t ssrc); + + +/** + * rtp信息获取回调 + * @param exist 存在rtp信息 0: 不存在 1: 存在 + * @param peer_ip 连接ip + * @param peer_port 连接端口 + * @param local_ip 本地ip + * @param local_port 本地端口 + * @param identifier 身份信息 + * + */ +typedef void(API_CALL *on_mk_rtp_get_info)(int exist, const char *peer_ip, uint16_t peer_port, const char *local_ip, uint16_t local_port, const char *identifier); + +/** + * 获取rtp推流信息 + * @param app 应用名 + * @param stream 流id + * @param cb rtp信息获取回调 + * + */ +API_EXPORT void API_CALL mk_rtp_get_info(const char *app, const char *stream, on_mk_rtp_get_info cb); + + +/** + * 暂停RTP超时检查 + * @param app 应用名 + * @param stream 流id + * + */ +API_EXPORT void API_CALL mk_rtp_pause_check(const char *app, const char *stream); + +/** + * 恢复RTP超时检查 + * @param app 应用名 + * @param stream 流id + * + */ +API_EXPORT void API_CALL mk_rtp_resume_check(const char *app, const char *stream); + #ifdef __cplusplus } #endif \ No newline at end of file diff --git a/api/include/mk_webrtc.h b/api/include/mk_webrtc.h new file mode 100644 index 00000000..3844e391 --- /dev/null +++ b/api/include/mk_webrtc.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef MK_WEBRTC_H +#define MK_WEBRTC_H +#include "mk_common.h" +#include "mk_proxyplayer.h" +#include + +#ifdef __cplusplus +extern "C" { +#endif + + +// 获取webrtc answer sdp回调函数 [AUTO-TRANSLATED:10c93fa9] +// Get webrtc answer sdp callback function +typedef void(API_CALL *on_mk_webrtc_get_answer_sdp)(void *user_data, const char *answer, const char *err); + +// 获取webrtc proxy player信息回调函数 +typedef void(API_CALL *on_mk_webrtc_get_proxy_player_info_cb)(const char *info_json, const char *err); + +//WebRTC-注册到信令服务器、WebRTC-从信令服务器注销回调函数 +typedef void(API_CALL *on_mk_webrtc_room_keeper_info_cb)(void *user_data, const char *room_key, const char *err); + +//获取WebRTC-Peer查看注册信息、WebRTC-信令服务器查看注册信息回调函数 +typedef void(API_CALL *on_mk_webrtc_room_keeper_data_cb)(const char *data); + + +/** + * webrtc交换sdp,根据offer sdp生成answer sdp + * @param user_data 回调用户指针 + * @param cb 回调函数 + * @param type webrtc插件类型,支持echo,play,push + * @param offer webrtc offer sdp + * @param url rtc url, 例如 rtc://__defaultVhost/app/stream?key1=val1&key2=val2 + * webrtc exchange sdp, generate answer sdp based on offer sdp + * @param user_data Callback user pointer + * @param cb Callback function + * @param type webrtc plugin type, supports echo, play, push + * @param offer webrtc offer sdp + * @param url rtc url, for example rtc://__defaultVhost/app/stream?key1=val1&key2=val2 + + * [AUTO-TRANSLATED:ea79659b] + */ +API_EXPORT void API_CALL mk_webrtc_get_answer_sdp(void *user_data, on_mk_webrtc_get_answer_sdp cb, const char *type, const char *offer, const char *url); + +API_EXPORT void API_CALL mk_webrtc_get_answer_sdp2( + void *user_data, on_user_data_free user_data_free, on_mk_webrtc_get_answer_sdp cb, const char *type, const char *offer, const char *url); + +/** + * 获取webrtc proxy player信息 + * @param mk_proxy_player 代理 + * @param cb 回调函数 + */ +API_EXPORT void API_CALL mk_webrtc_get_proxy_player_info(mk_proxy_player ctx, on_mk_webrtc_get_proxy_player_info_cb cb); + + +/** + * WebRTC-注册到信令服务器 + * @param server_host 信令服务器host + * @param server_port 信令服务器port + * @param room_id 房间id + * @param ssl 是否启用ssl + * @param cb 回调函数 + * @param user_data 用户数据 + */ +API_EXPORT void API_CALL +mk_webrtc_add_room_keeper(const char *room_id, const char *server_host, uint16_t server_port, int ssl, on_mk_webrtc_room_keeper_info_cb cb, void *user_data); + + +API_EXPORT void API_CALL mk_webrtc_add_room_keeper2( + const char *room_id, const char *server_host, uint16_t server_port, int ssl, on_mk_webrtc_room_keeper_info_cb cb, void *user_data, + on_user_data_free user_data_free); + + +/** + * WebRTC-从信令服务器注销 + * @param room_key 房间key + * @param cb 回调函数 + * @param user_data 用户数据 + */ +API_EXPORT void API_CALL mk_webrtc_del_room_keeper(const char *room_key, on_mk_webrtc_room_keeper_info_cb cb, void *user_data); + +API_EXPORT void API_CALL +mk_webrtc_del_room_keeper2(const char *room_key, on_mk_webrtc_room_keeper_info_cb cb, void *user_data, on_user_data_free user_data_free); + + +/** + * WebRTC-Peer查看注册信息 + * @param cb 回调函数 + */ +API_EXPORT void API_CALL mk_webrtc_list_room_keeper(on_mk_webrtc_room_keeper_data_cb cb); + +/** + * WebRTC-信令服务器查看注册信息 + * @param cb 回调函数 + */ +API_EXPORT void API_CALL mk_webrtc_list_rooms(on_mk_webrtc_room_keeper_data_cb cb); + +#ifdef __cplusplus +} +#endif + +#endif /* MK_WEBRTC_H */ \ No newline at end of file diff --git a/api/source/mk_common.cpp b/api/source/mk_common.cpp index 3bc274da..3e3f5ecc 100644 --- a/api/source/mk_common.cpp +++ b/api/source/mk_common.cpp @@ -29,6 +29,7 @@ using namespace mediakit; static TcpServer::Ptr rtsp_server[2]; static TcpServer::Ptr rtmp_server[2]; static TcpServer::Ptr http_server[2]; +static TcpServer::Ptr signaling_server[2]; static TcpServer::Ptr shell_server; #ifdef ENABLE_RTPPROXY @@ -37,9 +38,14 @@ static RtpServer::Ptr rtpServer; #endif #ifdef ENABLE_WEBRTC -#include "../webrtc/WebRtcSession.h" +#include "webrtc/WebRtcSession.h" +#include "webrtc/IceSession.hpp" +#include "webrtc/WebRtcSignalingSession.h" +#include "webrtc/WebRtcTransport.h" static UdpServer::Ptr rtcServer_udp; static TcpServer::Ptr rtcServer_tcp; +static UdpServer::Ptr iceServer_udp; +static TcpServer::Ptr iceServer_tcp; #endif #if defined(ENABLE_SRT) @@ -76,6 +82,9 @@ API_EXPORT void API_CALL mk_stop_all_server(){ #ifdef ENABLE_WEBRTC rtcServer_udp = nullptr; rtcServer_tcp = nullptr; + iceServer_udp = nullptr; + iceServer_tcp = nullptr; + CLEAR_ARR(signaling_server); #endif #ifdef ENABLE_SRT srtServer = nullptr; @@ -288,46 +297,46 @@ API_EXPORT uint16_t API_CALL mk_rtc_server_start(uint16_t port) { #endif } -#ifdef ENABLE_WEBRTC -class WebRtcArgsUrl : public mediakit::WebRtcArgs { -public: - WebRtcArgsUrl(std::string url) { _url = std::move(url); } - toolkit::variant operator[](const std::string &key) const override { - if (key == "url") { - return _url; +API_EXPORT uint16_t API_CALL mk_signaling_server_start(uint16_t port, int ssl) { +#ifdef ENABLE_WEBRTC + ssl = MAX(0, MIN(ssl, 1)); + try { + signaling_server[ssl] = std::make_shared(); + if (ssl) { + signaling_server[ssl]->start(port); + } else { + signaling_server[ssl]->start(port); } - return ""; + return signaling_server[ssl]->getPort(); + } catch (std::exception &ex) { + signaling_server[ssl] = nullptr; + WarnL << ex.what(); + return 0; } - -private: - std::string _url; -}; -#endif - -API_EXPORT void API_CALL mk_webrtc_get_answer_sdp(void *user_data, on_mk_webrtc_get_answer_sdp cb, const char *type, - const char *offer, const char *url) { - mk_webrtc_get_answer_sdp2(user_data, nullptr, cb, type, offer, url); -} -API_EXPORT void API_CALL mk_webrtc_get_answer_sdp2(void *user_data, on_user_data_free user_data_free, on_mk_webrtc_get_answer_sdp cb, const char *type, - const char *offer, const char *url) { -#ifdef ENABLE_WEBRTC - assert(type && offer && url && cb); - auto session = std::make_shared(Socket::createSocket()); - std::string offer_str = offer; - std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); - auto args = std::make_shared(url); - WebRtcPluginManager::Instance().negotiateSdp(*session, type, *args, [offer_str, session, ptr, cb](const WebRtcInterface &exchanger) mutable { - auto &handler = const_cast(exchanger); - try { - auto sdp_answer = handler.getAnswerSdp(offer_str); - cb(ptr.get(), sdp_answer.data(), nullptr); - } catch (std::exception &ex) { - cb(ptr.get(), nullptr, ex.what()); - } - }); #else WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; + return 0; +#endif +} + +API_EXPORT uint16_t API_CALL mk_ice_server_start(uint16_t port){ +#ifdef ENABLE_WEBRTC + try { + iceServer_tcp = std::make_shared(); + iceServer_udp = std::make_shared(); + iceServer_udp->start(port); + iceServer_tcp->start(port); + return 0; + } catch (std::exception &ex) { + iceServer_udp = nullptr; + iceServer_tcp = nullptr; + WarnL << ex.what(); + return 0; + } +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; + return 0; #endif } diff --git a/api/source/mk_events_objects.cpp b/api/source/mk_events_objects.cpp index 908a912b..50d040bb 100644 --- a/api/source/mk_events_objects.cpp +++ b/api/source/mk_events_objects.cpp @@ -296,6 +296,13 @@ API_EXPORT int API_CALL mk_media_source_seek_to(const mk_media_source ctx,uint32 MediaSource *src = (MediaSource *)ctx; return src->seekTo(stamp); } + +API_EXPORT void API_CALL mk_media_source_set_speed(const mk_media_source ctx, float speed) { + assert(ctx); + MediaSource *src = (MediaSource *)ctx; + src->getOwnerPoller()->async([=]() mutable { src->speed(speed); }); +} + API_EXPORT void API_CALL mk_media_source_start_send_rtp(const mk_media_source ctx, const char *dst_url, uint16_t dst_port, const char *ssrc, int con_type, on_mk_media_source_send_rtp_result cb, void *user_data) { mk_media_source_start_send_rtp2(ctx, dst_url, dst_port, ssrc, con_type, cb, user_data, nullptr); } @@ -347,6 +354,7 @@ API_EXPORT void API_CALL mk_media_source_start_send_rtp4(const mk_media_source c args.close_delay_ms = (*ini_ptr)["close_delay_ms"].empty() ? 0 : (*ini_ptr)["close_delay_ms"].as(); args.rtcp_timeout_ms = (*ini_ptr)["rtcp_timeout_ms"].empty() ? 30000 : (*ini_ptr)["rtcp_timeout_ms"].as(); args.rtcp_send_interval_ms = (*ini_ptr)["rtcp_send_interval_ms"].empty() ? 5000 : (*ini_ptr)["rtcp_send_interval_ms"].as(); + args.enable_origin_recv_limit = (*ini_ptr)["enable_origin_recv_limit"].empty() ? false : (*ini_ptr)["enable_origin_recv_limit"].as(); std::shared_ptr ptr( user_data, user_data_free ? user_data_free : [](void *) {}); src->getOwnerPoller()->async([=]() mutable { diff --git a/api/source/mk_frame.cpp b/api/source/mk_frame.cpp index c84453d3..94d5cb50 100644 --- a/api/source/mk_frame.cpp +++ b/api/source/mk_frame.cpp @@ -11,6 +11,7 @@ #include "mk_frame.h" #include "Record/MPEG.h" #include "Extension/Factory.h" +#include "Rtp/PSDecoder.h" using namespace mediakit; @@ -223,4 +224,36 @@ API_EXPORT int API_CALL mk_mpeg_muxer_input_frame(mk_mpeg_muxer ctx, mk_frame fr assert(ctx && frame); auto ptr = reinterpret_cast(ctx); return ptr->inputFrame(*((Frame::Ptr *) frame)); -} \ No newline at end of file +} + + +////////////////////////////////////////////////////////////////////// +#if defined(ENABLE_RTPPROXY) + +API_EXPORT mk_ps_decoder API_CALL mk_ps_decoder_create(on_mk_ps_decoder_stream scb, on_mk_ps_decoder_frame dcb, void * user_data) { + assert(dcb); + auto ps_decoder = new PSDecoder(); + std::shared_ptr ptr(user_data, [](void *) {}); + if (scb) { + ps_decoder->setOnStream([ptr,scb](int stream, int codecid, const void *extra, size_t bytes, int finish) { + scb(ptr.get(), stream, getCodecByMpegId(codecid), extra, bytes, finish); + }); + } + ps_decoder->setOnDecode([ptr,dcb](int stream, int codecid, int flags, int64_t pts, int64_t dts, const void *data, size_t bytes) { + dcb(ptr.get(), stream,getCodecByMpegId(codecid),flags,pts,dts,data,bytes); + }); + return reinterpret_cast(ps_decoder); +} + +API_EXPORT void API_CALL mk_ps_decoder_release(mk_ps_decoder ctx) { + assert(ctx); + auto ptr = reinterpret_cast(ctx); + delete ptr; +} + +API_EXPORT void API_CALL mk_ps_decoder_input(mk_ps_decoder ctx, const char * data, size_t bytes) { + assert(ctx && data); + auto ptr = reinterpret_cast(ctx); + ptr->input(reinterpret_cast(data), bytes); +} +#endif \ No newline at end of file diff --git a/api/source/mk_media.cpp b/api/source/mk_media.cpp index 65ae8e51..98127913 100755 --- a/api/source/mk_media.cpp +++ b/api/source/mk_media.cpp @@ -309,7 +309,7 @@ API_EXPORT void API_CALL mk_media_start_send_rtp2(mk_media ctx, const char *dst_ auto ref = *obj; std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); (*obj)->getChannel()->getOwnerPoller(MediaSource::NullMediaSource())->async([args, ref, cb, ptr]() { - ref->getChannel()->startSendRtp(MediaSource::NullMediaSource(), args, [cb, ptr](uint16_t local_port, const SockException &ex) { + ref->getChannel()->getMuxer(MediaSource::NullMediaSource())->startSendRtp( args, [cb, ptr](uint16_t local_port, const SockException &ex) { if (cb) { cb(ptr.get(), local_port, ex.getErrCode(), ex.what()); } @@ -343,13 +343,14 @@ API_EXPORT void API_CALL mk_media_start_send_rtp4(mk_media ctx, const char *dst_ args.close_delay_ms = (*ini_ptr)["close_delay_ms"].empty() ? 30000 : (*ini_ptr)["close_delay_ms"].as(); args.rtcp_timeout_ms = (*ini_ptr)["rtcp_timeout_ms"].empty() ? 30000 : (*ini_ptr)["rtcp_timeout_ms"].as(); args.rtcp_send_interval_ms = (*ini_ptr)["rtcp_send_interval_ms"].empty() ? 5000 : (*ini_ptr)["rtcp_send_interval_ms"].as(); + args.enable_origin_recv_limit = (*ini_ptr)["enable_origin_recv_limit"].empty() ? false : (*ini_ptr)["enable_origin_recv_limit"].as(); // sender参数无用 [AUTO-TRANSLATED:21590ae5] // The sender parameter is useless auto ref = *obj; std::shared_ptr ptr( user_data, user_data_free ? user_data_free : [](void *) {}); (*obj)->getChannel()->getOwnerPoller(MediaSource::NullMediaSource())->async([args, ref, cb, ptr]() { - ref->getChannel()->startSendRtp(MediaSource::NullMediaSource(), args, [cb, ptr](uint16_t local_port, const SockException &ex) { + ref->getChannel()->getMuxer(MediaSource::NullMediaSource())->startSendRtp(args, [cb, ptr](uint16_t local_port, const SockException &ex) { if (cb) { cb(ptr.get(), local_port, ex.getErrCode(), ex.what()); } @@ -365,7 +366,7 @@ API_EXPORT void API_CALL mk_media_stop_send_rtp(mk_media ctx, const char *ssrc) auto ref = *obj; string ssrc_str = ssrc ? ssrc : ""; (*obj)->getChannel()->getOwnerPoller(MediaSource::NullMediaSource())->async([ref, ssrc_str]() { - ref->getChannel()->stopSendRtp(MediaSource::NullMediaSource(), ssrc_str); + ref->getChannel()->getMuxer(MediaSource::NullMediaSource())->stopSendRtp(ssrc_str); }); } diff --git a/api/source/mk_recorder.cpp b/api/source/mk_recorder.cpp index d79f0654..7e5c9dfa 100644 --- a/api/source/mk_recorder.cpp +++ b/api/source/mk_recorder.cpp @@ -85,6 +85,27 @@ API_EXPORT int API_CALL mk_recorder_stop(int type, const char *vhost, const char return stopRecord((Recorder::type)type,vhost,app,stream); } +API_EXPORT int API_CALL mk_recorder_start_task(const char *vhost, const char *app, const char *stream, const char *path, uint32_t back_ms, uint32_t forward_ms) { + assert(vhost && app && stream); + auto src = MediaSource::find(vhost, app, stream); + if (!src) { + WarnL << "未找到相关的MediaSource,startRecordTask失败:" << vhost << "/" << app << "/" << stream; + return false; + } + bool ret; + src->getOwnerPoller()->async([=]() mutable { + std::string err; + try { + src->getMuxer()->startRecord(path, back_ms, forward_ms); + } catch (std::exception &ex) { + err = ex.what(); + WarnL << "MediaSource开启startRecordTask失败:" << vhost << "/" << app << "/" << stream << " what: " << err; + } + ret = err.empty(); + }); + return ret; +} + API_EXPORT void API_CALL mk_load_mp4_file(const char *vhost, const char *app, const char *stream, const char *file_path, int file_repeat) { mINI ini; mk_load_mp4_file2(vhost, app, stream, file_path, file_repeat, (mk_ini)&ini); diff --git a/api/source/mk_rtp_server.cpp b/api/source/mk_rtp_server.cpp index 96d478b0..4dd2f146 100644 --- a/api/source/mk_rtp_server.cpp +++ b/api/source/mk_rtp_server.cpp @@ -31,6 +31,13 @@ API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create2(uint16_t port, int tcp_m return (mk_rtp_server)server; } +API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create3(uint16_t port, int tcp_mode, const char *vhost, const char *app, const char *stream_id, int multiplex) { + RtpServer::Ptr *server = new RtpServer::Ptr(new RtpServer); + GET_CONFIG(std::string, local_ip, General::kListenIP) + (*server)->start(port, local_ip.c_str(), MediaTuple { vhost, app, stream_id, "" }, (RtpServer::TcpMode)tcp_mode,multiplex); + return (mk_rtp_server)server; +} + API_EXPORT void API_CALL mk_rtp_server_connect(mk_rtp_server ctx, const char *dst_url, uint16_t dst_port, on_mk_rtp_server_connected cb, void *user_data) { mk_rtp_server_connect2(ctx, dst_url, dst_port, cb, user_data, nullptr); } @@ -73,6 +80,41 @@ API_EXPORT void API_CALL mk_rtp_server_set_on_detach2(mk_rtp_server ctx, on_mk_r } } +API_EXPORT void API_CALL mk_rtp_server_update_ssrc(mk_rtp_server ctx, uint32_t ssrc) { + assert(ctx); + RtpServer::Ptr *server = (RtpServer::Ptr *)ctx; + (*server)->updateSSRC(ssrc); +} + + +API_EXPORT void API_CALL mk_rtp_get_info(const char *app, const char *stream, on_mk_rtp_get_info cb) { + assert(cb); + auto src = MediaSource::find(DEFAULT_VHOST, app, stream); + auto process = src ? src->getRtpProcess() : nullptr; + if (!process) { + cb(0, nullptr, 0, nullptr, 0, nullptr); + return; + } + SockInfo *info = process.get(); + cb(1, info->get_local_ip().c_str(), info->get_peer_port(), info->get_local_ip().c_str(), info->get_local_port(), info->getIdentifier().c_str()); +} + +API_EXPORT void API_CALL mk_rtp_pause_check(const char *app, const char *stream) { + auto src = MediaSource::find(DEFAULT_VHOST, app, stream); + auto process = src ? src->getRtpProcess() : nullptr; + if (process) { + process->pauseRtpTimeout(true); + } +} + +API_EXPORT void API_CALL mk_rtp_resume_check(const char *app, const char *stream) { + auto src = MediaSource::find(DEFAULT_VHOST, app, stream); + auto process = src ? src->getRtpProcess() : nullptr; + if (process) { + process->pauseRtpTimeout(false); + } +} + #else API_EXPORT mk_rtp_server API_CALL mk_rtp_server_create(uint16_t port, int enable_tcp, const char *stream_id) { diff --git a/api/source/mk_webrtc.cpp b/api/source/mk_webrtc.cpp new file mode 100644 index 00000000..2945bf16 --- /dev/null +++ b/api/source/mk_webrtc.cpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "mk_webrtc.h" +#include "mk_util.h" + +#include +#include +#include "Util/logger.h" +#include "Util/SSLBox.h" +#include "Util/File.h" +#include "Network/TcpServer.h" +#include "Network/UdpServer.h" +#include "Thread/WorkThreadPool.h" + +#include "Rtsp/RtspSession.h" +#include "Rtmp/RtmpSession.h" +#include "Http/HttpSession.h" +#include "Shell/ShellSession.h" +#include "Player/PlayerProxy.h" + +using namespace std; +using namespace toolkit; +using namespace mediakit; + +#ifdef ENABLE_WEBRTC + +#include "webrtc/WebRtcProxyPlayer.h" +#include "webrtc/WebRtcProxyPlayerImp.h" +#include "webrtc/WebRtcSignalingPeer.h" +#include "webrtc/WebRtcSignalingSession.h" +#include "webrtc/WebRtcSession.h" + +static UdpServer::Ptr rtcServer_udp; +static TcpServer::Ptr rtcServer_tcp; +class WebRtcArgsUrl : public mediakit::WebRtcArgs { +public: + WebRtcArgsUrl(std::string url) { _url = std::move(url); } + + toolkit::variant operator[](const std::string &key) const override { + if (key == "url") { + return _url; + } + return ""; + } + +private: + std::string _url; +}; +#endif + +API_EXPORT void API_CALL mk_webrtc_get_answer_sdp(void *user_data, on_mk_webrtc_get_answer_sdp cb, const char *type, const char *offer, const char *url) { + mk_webrtc_get_answer_sdp2(user_data, nullptr, cb, type, offer, url); +} +API_EXPORT void API_CALL mk_webrtc_get_answer_sdp2( + void *user_data, on_user_data_free user_data_free, on_mk_webrtc_get_answer_sdp cb, const char *type, const char *offer, const char *url) { +#ifdef ENABLE_WEBRTC + assert(type && offer && url && cb); + auto session = std::make_shared(Socket::createSocket()); + std::string offer_str = offer; + std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); + auto args = std::make_shared(url); + WebRtcPluginManager::Instance().negotiateSdp(*session, type, *args, [offer_str, session, ptr, cb](const WebRtcInterface &exchanger) mutable { + auto &handler = const_cast(exchanger); + try { + auto sdp_answer = handler.getAnswerSdp(offer_str); + cb(ptr.get(), sdp_answer.data(), nullptr); + } catch (std::exception &ex) { + cb(ptr.get(), nullptr, ex.what()); + } + }); +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; +#endif +} + +API_EXPORT void API_CALL mk_webrtc_get_proxy_player_info(mk_proxy_player ctx, on_mk_webrtc_get_proxy_player_info_cb cb) { +#ifdef ENABLE_WEBRTC + assert(ctx && cb); + PlayerProxy::Ptr *obj = (PlayerProxy::Ptr *)ctx; + auto media_player = obj->get()->getDelegate(); + if (!media_player) { + cb(nullptr, "Media player not found"); + return; + } + + auto webrtc_player_imp = std::dynamic_pointer_cast(media_player); + if (!webrtc_player_imp) { + cb(nullptr, "Stream proxy is not WebRTC type"); + return; + } + + auto webrtc_transport = webrtc_player_imp->getWebRtcTransport(); + if (!webrtc_transport) { + cb(nullptr, "WebRTC transport not available"); + return; + } + + webrtc_transport->getTransportInfo([cb](Json::Value transport_info) mutable { + if (transport_info.isMember("error")) { + cb(nullptr, strdup(transport_info["error"].asCString())); + return; + } + cb(strdup(transport_info.toStyledString().c_str()), ""); + }); +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; +#endif +} + +API_EXPORT void API_CALL mk_webrtc_add_room_keeper( + const char *room_id, const char *server_host, uint16_t server_port, int ssl, on_mk_webrtc_room_keeper_info_cb cb, void *user_data) { + mk_webrtc_add_room_keeper2(room_id, server_host, server_port, ssl, cb, user_data, nullptr); +} + +API_EXPORT void API_CALL mk_webrtc_add_room_keeper2( + const char *room_id, const char *server_host, uint16_t server_port, int ssl, on_mk_webrtc_room_keeper_info_cb cb, void *user_data, + on_user_data_free user_data_free) { +#ifdef ENABLE_WEBRTC + assert(server_host && server_port && room_id && cb); + // server_host: 信令服务器host + // server_post: 信令服务器host + // room_id: 注册的id,信令服务器会对该id进行唯一性检查 + std::string server_host_str(server_host), room_id_str(room_id); + std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); + addWebrtcRoomKeeper(server_host_str, server_port, room_id_str, ssl, [ptr,cb](const SockException &ex, const string &key) mutable { + if (ex) { + cb(ptr.get(), nullptr, ex.what()); + } else { + cb(ptr.get(), key.c_str(), nullptr); + } + }); +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; +#endif +} + +API_EXPORT void API_CALL mk_webrtc_del_room_keeper(const char *room_key, on_mk_webrtc_room_keeper_info_cb cb, void *user_data) { + mk_webrtc_del_room_keeper2(room_key,cb,user_data,nullptr); +} + +API_EXPORT void API_CALL +mk_webrtc_del_room_keeper2(const char *room_key, on_mk_webrtc_room_keeper_info_cb cb, void *user_data, on_user_data_free user_data_free) { +#ifdef ENABLE_WEBRTC + assert(room_key && cb); + std::string room_key_str(room_key); + std::shared_ptr ptr(user_data, user_data_free ? user_data_free : [](void *) {}); + delWebrtcRoomKeeper(room_key_str, [room_key_str, ptr, cb](const SockException &ex) mutable { + if (ex) { + cb(ptr.get(), room_key_str.c_str(), ex.what()); + } + cb(ptr.get(), room_key_str.c_str(), nullptr); + }); +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; +#endif +} + +API_EXPORT void API_CALL mk_webrtc_list_room_keeper(on_mk_webrtc_room_keeper_data_cb cb) { +#ifdef ENABLE_WEBRTC + assert(cb); + listWebrtcRoomKeepers([cb](const std::string &key, const WebRtcSignalingPeer::Ptr &p) { + Json::Value item = ToJson(p); + item["room_key"] = key; + cb(strdup(item.toStyledString().c_str())); + }); +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; +#endif +} + +API_EXPORT void API_CALL mk_webrtc_list_rooms(on_mk_webrtc_room_keeper_data_cb cb){ +#ifdef ENABLE_WEBRTC + assert(cb); + listWebrtcRooms([cb](const std::string &key, const WebRtcSignalingSession::Ptr &p) { + Json::Value item = ToJson(p); + item["room_id"] = key; + cb(strdup(item.toStyledString().c_str())); + }); +#else + WarnL << "未启用webrtc功能, 编译时请开启ENABLE_WEBRTC"; +#endif +} diff --git a/api/tests/CMakeLists.txt b/api/tests/CMakeLists.txt index 8108b71b..737ac6b6 100644 --- a/api/tests/CMakeLists.txt +++ b/api/tests/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/api/tests/pusher.c b/api/tests/pusher.c index bafbaaf5..e36b897d 100644 --- a/api/tests/pusher.c +++ b/api/tests/pusher.c @@ -64,7 +64,8 @@ void API_CALL on_mk_push_event_func(void *user_data,int err_code,const char *err void API_CALL on_mk_media_source_regist_func(void *user_data, mk_media_source sender, int regist){ Context *ctx = (Context *) user_data; const char *schema = mk_media_source_get_schema(sender); - if (strncmp(schema, ctx->push_url, strlen(schema)) == 0) { + if (strncmp(schema, ctx->push_url, strlen(schema)) == 0 || + (!strncmp(ctx->push_url, "webrtc", 5) && !strcmp(schema, "rtsp")) ) { // 判断是否为推流协议相关的流注册或注销事件 [AUTO-TRANSLATED:00a88a17] // Determine if it is a stream registration or deregistration event related to the streaming protocol release_pusher(&(ctx->pusher)); diff --git a/cmake/FindAVFILTER.cmake b/cmake/FindAVFILTER.cmake new file mode 100644 index 00000000..7b36f32a --- /dev/null +++ b/cmake/FindAVFILTER.cmake @@ -0,0 +1,16 @@ +find_path(AVFILTER_INCLUDE_DIR + NAMES libavfilter/avfilter.h + HINTS ${FFMPEG_PATH_ROOT} + PATH_SUFFIXES include) + +find_library(AVFILTER_LIBRARY + NAMES avfilter + HINTS ${FFMPEG_PATH_ROOT} + PATH_SUFFIXES bin lib) + +set(AVFILTER_LIBRARIES ${AVFILTER_LIBRARY}) +set(AVFILTER_INCLUDE_DIRS ${AVFILTER_INCLUDE_DIR}) + +include(FindPackageHandleStandardArgs) + +find_package_handle_standard_args(AVFILTER DEFAULT_MSG AVFILTER_LIBRARY AVFILTER_INCLUDE_DIR) diff --git a/cmake/FindTCMALLOC.cmake b/cmake/FindTCMALLOC.cmake new file mode 100644 index 00000000..5ae681aa --- /dev/null +++ b/cmake/FindTCMALLOC.cmake @@ -0,0 +1,16 @@ +find_path(Tcmalloc_INCLUDE_DIR + NAMES google/tcmalloc.h +) + +find_library(Tcmalloc_LIBRARY + NAMES tcmalloc_minimal tcmalloc +) + +set(TCMALLOC_LIBRARIES ${Tcmalloc_LIBRARY}) +set(TCMALLOC_INCLUDE_DIRS ${Tcmalloc_INCLUDE_DIR}) + +INCLUDE(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(TCMALLOC + DEFAULT_MSG + TCMALLOC_LIBRARIES TCMALLOC_INCLUDE_DIRS +) diff --git a/cmake/checks/atomic_check.cpp b/cmake/checks/atomic_check.cpp index f25e2d9f..283e61e8 100644 --- a/cmake/checks/atomic_check.cpp +++ b/cmake/checks/atomic_check.cpp @@ -1,4 +1,4 @@ -#include +#include static int test() { diff --git a/conf/config.ini b/conf/config.ini index a434d42a..7a1b3a0e 100644 --- a/conf/config.ini +++ b/conf/config.ini @@ -4,52 +4,93 @@ #!!!!你如果修改此范例配置文件(conf/config.ini),并不会被MediaServer进程加载,因为MediaServer进程默认加载的是release/${操作系统类型}/${编译类型}/config.ini。 #!!!!当然,你每次执行cmake,该文件确实会被拷贝至release/${操作系统类型}/${编译类型}/config.ini, #!!!!但是一般建议你直接修改release/${操作系统类型}/${编译类型}/config.ini文件,修改此文件一般不起作用,除非你运行MediaServer时使用-c参数指定到此文件。 +#!!!! This is a sample configuration file intended to explain the specific meanings and functions of each item. +#!!!! During the `cmake` execution, this file is copied to the `release/${OS type}/${build type}` directory. +#!!!! This directory is also the target path where the MediaServer executable runs and looks for `config.ini` by default. +#!!!! Modifying this sample file (`conf/config.ini`) will not affect the MediaServer process while it runs. +#!!!! Although executing `cmake` overwrites the target config file, it is highly recommended to modify `release/${OS type}/${build type}/config.ini` directly. +#!!!! Changes made here will only take effect if you explicitly load this file using the `-c` parameter when starting the MediaServer. [api] -#是否调试http api,启用调试后,会打印每次http请求的内容和回复 +# 是否调试http api,启用调试后,会打印每次http请求的内容和回复 +# Enable HTTP API debugging. When enabled, it logs the content and responses of each HTTP request. apiDebug=1 -#一些比较敏感的http api在访问时需要提供secret,否则无权限调用 -#如果是通过127.0.0.1访问,那么可以不提供secret + +# 一些比较敏感的http api在访问时需要提供secret,否则无权限调用 +# 如果是通过127.0.0.1访问,那么可以不提供secret +# For some sensitive HTTP APIs, a secret must be provided when accessing them, otherwise the call is unauthorized. +# If accessed via 127.0.0.1, the secret does not need to be provided. secret=035c73f7-bb6b-4889-a715-d9eb2d1925cc -#截图保存路径根目录,截图通过http api(/index/api/getSnap)生成和获取 + +# 截图保存路径根目录,截图通过http api(/index/api/getSnap)生成和获取 +# Root directory for saving snapshots generated via the `/index/api/getSnap` API. snapRoot=./www/snap/ -#默认截图图片,在启动FFmpeg截图后但是截图还未生成时,可以返回默认的预设图片 + +# 默认截图图片,在启动FFmpeg截图后但是截图还未生成时,可以返回默认的预设图片 +# Default placeholder image returned while FFmpeg is generating the actual snapshot. defaultSnap=./www/logo.png -#downloadFile http接口可访问文件的根目录,支持多个目录,不同目录通过分号(;)分隔 + +# downloadFile http接口可访问文件的根目录,支持多个目录,不同目录通过分号(;)分隔 +# Root directories accessible via the `downloadFile` API. Separate multiple directories with semicolons (;). downloadRoot=./www [ffmpeg] -#FFmpeg可执行程序路径,支持相对路径/绝对路径 +# FFmpeg可执行程序路径,支持相对路径/绝对路径 +# Path to the FFmpeg executable. Both relative and absolute paths are supported. bin=/usr/bin/ffmpeg -#FFmpeg拉流再推流的命令模板,通过该模板可以设置再编码的一些参数 + +# FFmpeg拉流再推流的命令模板,通过该模板可以设置诸如编码等的一些参数 +# FFmpeg command template for pulling and re-publishing streams (used to define re-encoding parameters). cmd=%s -re -i %s -c:a aac -strict -2 -ar 44100 -ab 48k -c:v libx264 -f flv %s -#FFmpeg生成截图的命令,可以通过修改该配置改变截图分辨率或质量 + +# FFmpeg生成截图的命令,可以通过修改该配置改变截图分辨率或质量 +# FFmpeg command template for generating snapshots. Modify this to change resolution or quality. snap=%s -i %s -y -f mjpeg -frames:v 1 -an %s -#FFmpeg日志的路径,如果置空则不生成FFmpeg日志 -#可以为相对(相对于本可执行程序目录)或绝对路径 + +# FFmpeg日志的路径,如果置空则不生成FFmpeg日志 +# 可以为相对(相对于本可执行程序目录)或绝对路径 +# Path to the FFmpeg log file (relative or absolute). Leave empty to disable logging. log=./ffmpeg/ffmpeg.log + # 自动重启的时间(秒), 默认为0, 也就是不自动重启. 主要是为了避免长时间ffmpeg拉流导致的不同步现象 +# Automatic restart interval in seconds (0 to disable). Helps prevent A/V desync caused by prolonged FFmpeg stream pulling. restart_sec=0 -#转协议相关开关;如果addStreamProxy api和on_publish hook回复未指定转协议参数,则采用这些配置项 +# 转协议相关开关;如果addStreamProxy api和on_publish hook回复未指定转协议参数,则采用这些配置项 +# Protocol conversion default switches. Used if protocol conversions aren't specified via the `addStreamProxy` API or the `on_publish` webhook. [protocol] -#转协议时,是否开启帧级时间戳覆盖 +# 转协议时,是否开启帧级时间戳覆盖 # 0:采用源视频流绝对时间戳,不做任何改变 # 1:采用zlmediakit接收数据时的系统时间戳(有平滑处理) # 2:采用源视频流时间戳相对时间戳(增长量),有做时间戳跳跃和回退矫正 +# Frame-level timestamp override mode during protocol conversion: +# - 0: Use absolute timestamp from the source (no modification). +# - 1: Use ZLMediaKit system timestamp upon data reception (with smoothing). +# - 2: Use relative timestamp increments, with correction for jumps and backwards drifts. modify_stamp=2 -#转协议是否开启音频 + +# 转协议是否开启音频 +# Whether to enable audio output during protocol conversion. enable_audio=1 -#添加acc静音音频,在关闭音频时,此开关无效 + +# 添加AAC静音音频,在关闭音频时,此开关无效 +# Whether to inject AAC silent audio (ignored if `enable_audio` is 0). add_mute_audio=1 -#无人观看时,是否直接关闭(而不是通过on_none_reader hook返回close) -#此配置置1时,此流如果无人观看,将不触发on_none_reader hook回调, -#而是将直接关闭流 + +# 无人观看时,是否直接关闭(而不是通过on_none_reader hook返回close) +# 此配置置1时,此流如果无人观看,将不触发on_none_reader hook回调, +# 而是将直接关闭流 +# Whether to immediately close an unwatched stream directly instead of relying on the `on_none_reader` hook returning 'close'. +# If enabled (1), an unwatched stream is closed outright without triggering the hook callback. auto_close=0 -#推流断开后可以在超时时间内重新连接上继续推流,这样播放器会接着播放。 -#置0关闭此特性(推流断开会导致立即断开播放器) -#此参数不应大于播放器超时时间;单位毫秒 +# 推流断开后可以在超时时间内重新连接上继续推流,这样播放器会接着播放。 +# 置0关闭此特性(推流断开会导致立即断开播放器) +# 此参数不应大于播放器超时时间;单位毫秒 +# Defines a grace period (in milliseconds) allowing a disconnected publisher to reconnect and resume streaming. +# During this period, active player connections are maintained rather than dropped. +# Set to 0 to disable this feature, which means dropping a publisher will immediately disconnect all its current players. +# This value must not exceed the player's configured timeout. continue_push_ms=15000 # 是否启用音频转码 # 主要实现进出RTC音频流的自动转码,代码实现详见 RtcMediaSource.h/cpp,当前实现 @@ -60,112 +101,204 @@ continue_push_ms=15000 # 此外音频转码正常都是用于webrtc的,一般也会开启WEBRTC, 即-DENABLE_WEBRTC=1, 此前必须自己装好libsrtp库, 安装过程详见wiki # audio_transcode配置项可通过配置文件,hook来打开,注意如果编译时没启用FFMPEG,此选项会自动关闭,使用此分支前得先确保启用FFMPEG! audio_transcode=1 -#平滑发送定时器间隔,单位毫秒,置0则关闭;开启后影响cpu性能同时增加内存 -#该配置开启后可以解决一些流发送不平滑导致zlmediakit转发也不平滑的问题 +# 平滑发送定时器间隔,单位毫秒,置0则关闭;开启后影响cpu性能同时增加内存 +# 该配置开启后可以解决一些流发送不平滑导致zlmediakit转发也不平滑的问题 +# Smooth sending timer interval in milliseconds (0 to disable). Enabling this increases CPU and memory usage. +# This solves the issue where unsteady upstream publishing causes ZLMediaKit's forwarding to also be unsteady. paced_sender_ms=0 -#是否开启转换为hls(mpegts) +# 是否开启转换为hls(mpegts) +# Whether to enable conversion to HLS (mpegts). enable_hls=1 -#是否开启转换为hls(fmp4) + +# 是否开启转换为hls(fmp4) +# Whether to enable conversion to HLS (fmp4). enable_hls_fmp4=0 -#是否开启MP4录制 + +# 是否开启MP4录制 +# Whether to enable MP4 recording. enable_mp4=0 -#是否开启转换为rtsp + +# 是否开启转换为rtsp +# Whether to enable conversion to RTSP. enable_rtsp=1 -#是否开启转换为webrtc +# 是否开启转换为rtc +# Whether to enable conversion to WEBRTC. enable_rtc=1 -#是否开启转换为rtmp/flv + +# 是否开启转换为rtmp/flv +# Whether to enable conversion to RTMP/FLV. enable_rtmp=1 -#是否开启转换为http-ts/ws-ts + +# 是否开启转换为http-ts/ws-ts +# Whether to enable conversion to HTTP-TS/WS-TS. enable_ts=1 -#是否开启转换为http-fmp4/ws-fmp4 + +# 是否开启转换为http-fmp4/ws-fmp4 +# Whether to enable conversion to HTTP-FMP4/WS-FMP4. enable_fmp4=1 -#是否将mp4录制当做观看者 +# 是否将mp4录制当做观看者 +# Whether to treat MP4 recording tasks as active stream viewers. mp4_as_player=0 -#mp4切片大小,单位秒 + +# mp4切片大小,单位秒 +# Maximum duration of MP4 recording segments in seconds. mp4_max_second=3600 -#mp4录制保存路径 + +# mp4录制保存路径 +# Directory path for saving MP4 recordings. mp4_save_path=./www -#hls录制保存路径 +# hls录制保存路径 +# Directory path for saving HLS recordings. hls_save_path=./www ###### 以下是按需转协议的开关,在测试ZLMediaKit的接收推流性能时,请把下面开关置1 -###### 如果某种协议你用不到,你可以把以下开关置1以便节省资源(但是还是可以播放,只是第一个播放者体验稍微差点), -###### 如果某种协议你想获取最好的用户体验,请置0(第一个播放者可以秒开,且不花屏) -#hls协议是否按需生成,如果hls.segNum配置为0(意味着hls录制),那么hls将一直生成(不管此开关) +###### 对于不使用的协议,可以将开关设置为 1 以节省资源(虽然首个播放者体验稍差,但依然可以播放)。 +###### 对于希望获得最佳用户体验的协议,请设置为 0(首屏秒开,且无花屏现象)。 +###### On-demand protocol conversion switches. Set these to 1 during stream reception performance testing to save resources. +###### For unused protocols, setting them to 1 saves resources (with a slight startup delay for the first viewer). +###### For the best user experience (instant playback and no visual artifacts (glitches)), set them to 0. + +# hls协议是否按需生成,如果hls.segNum配置为0(意味着hls录制),那么hls将一直生成(不管此开关) +# Whether to generate HLS streams on demand. If `hls.segNum` is configured to 0 (implies HLS recording), HLS streams generate continuously regardless of this switch. hls_demand=0 -#rtsp[s]协议是否按需生成 + +# rtsp[s]协议是否按需生成 +# Whether to generate RTSP[S] streams on demand. rtsp_demand=0 -#rtc协议是否按需生成 + +# rtc协议是否按需生成 +# Whether to generate WEBRTC streams on demand. rtc_demand=0 -#rtmp[s]、http[s]-flv、ws[s]-flv协议是否按需生成 + + +# rtmp[s]、http[s]-flv、ws[s]-flv协议是否按需生成 +# Whether to generate RTMP[S], HTTP[S]-FLV, and WS[S]-FLV streams on demand. rtmp_demand=0 -#http[s]-ts协议是否按需生成 + +# http[s]-ts协议是否按需生成 +# Whether to generate HTTP[S]-TS streams on demand. ts_demand=0 -#http[s]-fmp4、ws[s]-fmp4协议是否按需生成 + +# http[s]-fmp4、ws[s]-fmp4协议是否按需生成 +# Whether to generate HTTP[S]-FMP4 and WS[S]-FMP4 streams on demand. fmp4_demand=0 [general] -#是否启用虚拟主机 +# 是否启用虚拟主机 +# Whether to enable virtual hosting. enableVhost=0 -#播放器或推流器在断开后会触发hook.on_flow_report事件(使用多少流量事件), -#flowThreshold参数控制触发hook.on_flow_report事件阈值,使用流量超过该阈值后才触发,单位KB + +# 播放器或推流器在断开后会触发hook.on_flow_report事件(使用多少流量事件), +# flowThreshold参数控制触发hook.on_flow_report事件阈值,使用流量超过该阈值后才触发,单位KB +# When a player or publisher disconnects, it triggers the `hook.on_flow_report` event (an event reporting how much traffic was used). +# The `flowThreshold` parameter controls the threshold for triggering the `hook.on_flow_report` event; it is only triggered when the used traffic exceeds this threshold, in KB. flowThreshold=1024 -#播放最多等待时间,单位毫秒 -#播放在播放某个流时,如果该流不存在, -#ZLMediaKit会最多让播放器等待maxStreamWaitMS毫秒 -#如果在这个时间内,该流注册成功,那么会立即返回播放器播放成功 -#否则返回播放器未找到该流,该机制的目的是可以先播放再推流 + +# 播放最多等待时间,单位毫秒 +# 播放在播放某个流时,如果该流不存在, +# ZLMediaKit会最多让播放器等待maxStreamWaitMS毫秒 +# 如果在这个时间内,该流注册成功,那么会立即返回播放器播放成功 +# 否则返回播放器未找到该流,该机制的目的是可以先播放再推流 +# Maximum playback wait time in milliseconds. +# When a requested stream does not exist, ZLMediaKit delays the player for up to `maxStreamWaitMS`. +# If the stream is successfully registered within this period, it immediately returns playback success. +# Otherwise, it returns 'stream not found'. This mechanism enables 'play before push' workflows. maxStreamWaitMS=15000 -#某个流无人观看时,触发hook.on_stream_none_reader事件的最大等待时间,单位毫秒 -#在配合hook.on_stream_none_reader事件时,可以做到无人观看自动停止拉流或停止接收推流 + +# 某个流无人观看时,触发hook.on_stream_none_reader事件的最大等待时间,单位毫秒 +# 在配合hook.on_stream_none_reader事件时,可以做到无人观看自动停止拉流或停止接收推流 +# The continuous unwatched duration (in ms) required to trigger the `hook.on_stream_none_reader` event. +# Combined with the `hook.on_stream_none_reader` event, this enables automatically stopping origin pulls or disconnecting publishers when a stream remains unwatched. streamNoneReaderDelayMS=20000 -#拉流代理时如果断流再重连成功是否删除前一次的媒体流数据,如果删除将重新开始, -#如果不删除将会接着上一次的数据继续写(录制hls/mp4时会继续在前一个文件后面写) + +# 拉流代理时如果断流再重连成功是否删除前一次的媒体流数据,如果删除将重新开始, +# 如果不删除将会接着上一次的数据继续写(录制hls/mp4时会继续在前一个文件后面写) +# Whether to flush cached media data upon successfully reconnecting after an origin pull proxy disconnection. If flushed, the stream restarts cleanly. +# If not flushed, the new data will append directly to the previous data (when recording HLS/MP4, it continues appending to the previous file). resetWhenRePlay=1 -#合并写缓存大小(单位毫秒),合并写指服务器缓存一定的数据后才会一次性写入socket,这样能提高性能,但是会提高延时 -#开启后会同时关闭TCP_NODELAY并开启MSG_MORE + +# 合并写缓存大小(单位毫秒),合并写指服务器缓存一定的数据后才会一次性写入socket,这样能提高性能,但是会提高延时 +# 开启后会同时关闭TCP_NODELAY并开启MSG_MORE +# Write coalescing cache duration in ms. The server caches data up to this interval before writing to the socket in bulk, improving performance at the cost of slight latency. +# Enabling this disables `TCP_NODELAY` and enables `MSG_MORE`. mergeWriteMS=0 -#服务器唯一id,用于触发hook时区别是哪台服务器 + +# 服务器唯一id,用于触发hook时区别是哪台服务器 +# Unique server ID, used to distinguish which server it is when triggering a hook. mediaServerId=your_server_id -#最多等待未初始化的Track时间,单位毫秒,超时之后会忽略未初始化的Track +# 最多等待未初始化的Track时间,单位毫秒,超时之后会忽略未初始化的Track +# Maximum wait time (in ms) for uninitialized Tracks. After the timeout, any uninitialized Tracks will be ignored. wait_track_ready_ms=10000 -#最多等待音频Track收到数据时间,单位毫秒,超时且完全没收到音频数据,忽略音频Track -#加快某些带封装的流metadata说明有音频,但是实际上没有的流ready时间(比如很多厂商的GB28181 PS) + +# 最多等待音频Track收到数据时间,单位毫秒,超时且完全没收到音频数据,忽略音频Track +# 加快某些带封装的流metadata说明有音频,但是实际上没有的流ready时间(比如很多厂商的GB28181 PS) +# Maximum wait time (in ms) before an audio track receives its first data packet. If it times out and absolutely no audio data has been received, the audio Track is ignored. +# This speeds up the ready time for certain packaged streams whose metadata falsely claims to include audio, but actually do not (e.g., GB28181 PS). wait_audio_track_data_ms=1000 -#如果流只有单Track,最多等待若干毫秒,超时后未收到其他Track的数据,则认为是单Track -#如果协议元数据有声明特定track数,那么无此等待时间 + +# 如果流只有单Track,最多等待若干毫秒,超时后未收到其他Track的数据,则认为是单Track +# 如果协议元数据有声明特定track数,那么无此等待时间 +# Maximum wait time (in ms) for additional tracks if a stream currently has only one. +# If no data from other tracks is received within this timeout, it is considered a single-track stream. +# This delay is bypassed if protocol metadata explicitly declares the track count. wait_add_track_ms=3000 -#如果track未就绪,我们先缓存帧数据,但是有最大个数限制,防止内存溢出 + +# 如果track未就绪,我们先缓存帧数据,但是有最大个数限制,防止内存溢出 +# If a track is not ready, we first cache the frame data, but there is a maximum count limit to prevent memory overflow. unready_frame_cache=100 -#是否启用观看人数变化事件广播,置1则启用,置0则关闭 + +# 是否启用观看人数变化事件广播,置1则启用,置0则关闭 +# Whether to enable broadcasting of viewership change events. Set to 1 to enable, set to 0 to disable. broadcast_player_count_changed=0 -#绑定的本地网卡ip + +# 绑定的本地网卡ip +# Bound local network interface IP address. listen_ip=:: [hls] -#hls写文件的buf大小,调整参数可以提高文件io性能 +# hls写文件的buf大小,调整参数可以提高文件io性能 +# Buffer size used when writing HLS segment files. Increasing this value can improve disk I/O performance. fileBufSize=65536 -#hls最大切片时间 + +# hls最大切片时间 +# Target maximum duration of a single HLS segment. segDur=2 -#m3u8索引中,hls保留切片个数(实际保留切片个数+segRetain个) -#如果设置为0,则不删除切片且m3u8文件全量记录切片列表 + +# m3u8索引中,hls保留切片个数(实际保留切片个数+segRetain个) +# 如果设置为0,则不删除切片且m3u8文件全量记录切片列表 +# Number of HLS segments retained within the m3u8 playlist index (actual chunks kept = this value + `segRetain`). +# Set to 0 to retain all segments and record the full segment list in the m3u8 file. segNum=3 -#HLS切片延迟个数,大于0将生成hls_delay.m3u8文件,0则不生成 + +# HLS切片延迟个数,大于0将生成hls_delay.m3u8文件,0则不生成 +# The segment delay count for HLS. If greater than 0, an `hls_delay.m3u8` variant playlist is generated; if 0, it will not be generated. segDelay=0 -#HLS切片从m3u8文件中移除后,继续保留在磁盘上的个数 + +# HLS切片从m3u8文件中移除后,继续保留在磁盘上的个数 +# Number of outdated HLS segments to keep on disk after removal from the m3u8 playlist. segRetain=5 -#是否广播 hls切片(ts/fmp4)完成通知(on_record_ts) + +# 是否广播 hls切片(ts/fmp4)完成通知(on_record_ts) +# Whether to broadcast HLS segment (TS/FMP4) completion notifications via `on_record_ts`. broadcastRecordTs=0 -#直播hls文件删除延时,单位秒,issue: #913 + +# 直播hls文件删除延时,单位秒,issue: #913 +# Delay in seconds before deleting expired live HLS segments. Refer to issue: #913. deleteDelaySec=10 -#此选项开启后m3u8文件还是表现为直播,但是切片文件会被全部保留为点播用 -#segDur设置为0或segKeep设置为1的情况下,每个切片文件夹下会生成一个vod.m3u8文件用于点播该时间段的录像 + +# 此选项开启后m3u8文件还是表现为直播,但是切片文件会被全部保留为点播用 +# segDur设置为0或segKeep设置为1的情况下,每个切片文件夹下会生成一个vod.m3u8文件用于点播该时间段的录像 +# When enabled, the `m3u8` playlist functions as live media, but segment chunks are permanently preserved in storage for Video On Demand (VOD). +# If either `segKeep` is 1 or `segDur` is 0, a `vod.m3u8` playlist is also generated in each segment's folder for VOD playback of that specific recorded period. segKeep=0 -#如果设置为1,则第一个切片长度强制设置为1个GOP。当GOP小于segDur,可以提高首屏速度 + +# 如果设置为1,则第一个切片长度强制设置为1个GOP。当GOP小于segDur,可以提高首屏速度 +# If set to 1, the length of the first segment is forcibly set to exactly 1 GOP. +# When the GOP is smaller than `segDur`, this can improve the initial startup (instant playback) speed. fastRegister=0 # 转码成opus音频时的比特率 opusBitrate=64000 @@ -173,286 +306,475 @@ opusBitrate=64000 aacBitrate=64000 [hook] -#是否启用hook事件,启用后,推拉流都将进行鉴权 +# 是否启用hook事件,启用后,推拉流都将进行鉴权 +# Whether to enable webhook events. When enabled, pushing and pulling streams requires authentication. enable=0 -#播放器或推流器使用流量事件,置空则关闭 +# 播放器或推流器使用流量事件,置空则关闭 +# Player or publisher flow traffic event. Leave empty to disable. on_flow_report= -#访问http文件鉴权事件,置空则关闭鉴权 +# 访问http文件鉴权事件,置空则关闭鉴权 +# HTTP file access authentication event. Leave empty to disable. on_http_access= -#播放鉴权事件,置空则关闭鉴权 +# 播放鉴权事件,置空则关闭鉴权 +# Playback authentication event. Leave empty to disable. on_play= -#推流鉴权事件,置空则关闭鉴权 +# 推流鉴权事件,置空则关闭鉴权 +# Publishing authentication event. Leave empty to disable. on_publish= -#录制mp4切片完成事件 +# 录制mp4切片完成事件 +# MP4 segment recording completion event. on_record_mp4= # 录制 hls ts(或fmp4) 切片完成事件 +# HLS TS (or fmp4) segment recording completion event. on_record_ts= -#rtsp播放鉴权事件,此事件中比对rtsp的用户名密码 +# rtsp播放鉴权事件,此事件中比对rtsp的用户名密码 +# RTSP playback authentication event (used to verify RTSP username and password). on_rtsp_auth= -#rtsp播放是否开启专属鉴权事件,置空则关闭rtsp鉴权。rtsp播放鉴权还支持url方式鉴权 -#建议开发者统一采用url参数方式鉴权,rtsp用户名密码鉴权一般在设备上用的比较多 -#开启rtsp专属鉴权后,将不再触发on_play鉴权事件 +# rtsp播放是否开启专属鉴权事件,置空则关闭rtsp鉴权。rtsp播放鉴权还支持url方式鉴权 +# 建议开发者统一采用url参数方式鉴权,rtsp用户名密码鉴权一般在设备上用的比较多 +# 开启rtsp专属鉴权后,将不再触发on_play鉴权事件 +# Whether to enable a dedicated RTSP realm authentication event (leave empty to disable; URL-based auth remains supported). +# We recommend standardizing on URL parameters; RTSP username/password auth is mostly for hardware devices. +# Enabling this bypasses the standard `on_play` webhook. on_rtsp_realm= -#远程telnet调试鉴权事件 +# 远程telnet调试鉴权事件 +# Remote telnet debugging authentication event. on_shell_login= -#直播流注册或注销事件 +# 直播流注册或注销事件 +# Live stream registration or unregistration event. on_stream_changed= -#过滤on_stream_changed hook的协议类型,可以选择只监听某些感兴趣的协议;置空则不过滤协议 +# 过滤on_stream_changed hook的协议类型,可以选择只监听某些感兴趣的协议;置空则不过滤协议 +# Filter the protocol types for the `on_stream_changed` hook to listen only to specific protocols. Leave empty to disable filtering. stream_changed_schemas=rtsp/rtmp/fmp4/ts/hls/hls.fmp4 -#无人观看流事件,通过该事件,可以选择是否关闭无人观看的流。配合general.streamNoneReaderDelayMS选项一起使用 +# 无人观看流事件,通过该事件,可以选择是否关闭无人观看的流。配合general.streamNoneReaderDelayMS选项一起使用 +# Triggered when a stream has no viewers. Combined with `general.streamNoneReaderDelayMS`, this enables closing unwatched streams. on_stream_none_reader= -#播放时,未找到流事件,通过配合hook.on_stream_none_reader事件可以完成按需拉流 +# 播放时,未找到流事件,通过配合hook.on_stream_none_reader事件可以完成按需拉流 +# Triggered when a requested stream is not found. Combined with `hook.on_stream_none_reader`, this enables on-demand origin pulling. on_stream_not_found= -#服务器启动报告,可以用于服务器的崩溃重启事件监听 +# 服务器启动报告,可以用于服务器的崩溃重启事件监听 +# Server startup report. Useful for monitoring server crashes and restarts. on_server_started= -#服务器退出报告,当服务器正常退出时触发 +# 服务器退出报告,当服务器正常退出时触发 +# Server exit report, triggered when the server shuts down normally. on_server_exited= -#server保活上报 +# server保活上报 +# Server keep-alive reporting event. on_server_keepalive= -#发送rtp(startSendRtp)被动关闭时回调 +# 发送rtp(startSendRtp)被动关闭时回调 +# Callback triggered when RTP sending (`startSendRtp`) is passively closed. on_send_rtp_stopped= -#rtp server 超时未收到数据 +# rtp server 超时未收到数据 +# RTP server timeout event due to not receiving data. on_rtp_server_timeout= - -#hook api最大等待回复时间,单位秒 +# hook api最大等待回复时间,单位秒 +# Maximum wait time in seconds for Webhook API responses. timeoutSec=10 -#keepalive hook触发间隔,单位秒,float类型 +# keepalive hook触发间隔,单位秒,float类型 +# Interval in seconds (float) for triggering the keep-alive webhook. alive_interval=10.0 -#hook通知失败重试次数,正整数。为0不重试,1时重试一次,以此类推 +# hook通知失败重试次数,正整数。为0不重试,1时重试一次,以此类推 +# Webhook notification failure retry attempts. Must be a non-negative integer (0 to disable). retry=1 -#hook通知失败重试延时,单位秒,float型 +# hook通知失败重试延时,单位秒,float型 +# Delay in seconds (float) between webhook retry attempts. retry_delay=3.0 [cluster] -#设置源站拉流url模板, 格式跟printf类似,第一个%s指定app,第二个%s指定stream_id, -#开启集群模式后,on_stream_not_found和on_stream_none_reader hook将无效. -#溯源模式支持以下类型: -#rtmp方式: rtmp://127.0.0.1:1935/%s/%s -#rtsp方式: rtsp://127.0.0.1:554/%s/%s -#hls方式: http://127.0.0.1:80/%s/%s/hls.m3u8 -#http-ts方式: http://127.0.0.1:80/%s/%s.live.ts -#支持多个源站,不同源站通过分号(;)分隔 +# 设置源站拉流url模板, 格式跟printf类似,第一个%s指定app,第二个%s指定stream_id, +# 开启集群模式后,on_stream_not_found和on_stream_none_reader hook将无效. +# 溯源模式支持以下类型: +# rtmp方式: rtmp://127.0.0.1:1935/%s/%s +# rtsp方式: rtsp://127.0.0.1:554/%s/%s +# hls方式: http://127.0.0.1:80/%s/%s/hls.m3u8 +# http-ts方式: http://127.0.0.1:80/%s/%s.live.ts +# 支持多个源站,不同源站通过分号(;)分隔 +# Origin pull URL template (printf style: first `%s` is app, second `%s` is stream_id). +# When cluster mode is enabled, `on_stream_not_found` and `on_stream_none_reader` webhooks are disabled. +# Supported origin pull protocols: +# RTMP mode: rtmp://127.0.0.1:1935/%s/%s +# RTSP mode: rtsp://127.0.0.1:554/%s/%s +# HLS mode: http://127.0.0.1:80/%s/%s/hls.m3u8 +# HTTP-TS mode: http://127.0.0.1:80/%s/%s.live.ts +# Separate multiple origin servers with semicolons (;). origin_url= -#溯源总超时时长,单位秒,float型;假如源站有3个,那么单次溯源超时时间为timeout_sec除以3 -#单次溯源超时时间不要超过general.maxStreamWaitMS配置 +# 溯源总超时时长,单位秒,float型;假如源站有3个,那么单次溯源超时时间为timeout_sec除以3 +# 单次溯源超时时间不要超过general.maxStreamWaitMS配置 +# Total origin pull timeout in seconds (float). +# The single origin attempt timeout (total timeout divided by the number of origins) should not exceed `general.maxStreamWaitMS`. timeout_sec=15 -#溯源失败尝试次数,-1时永久尝试 +# 溯源失败尝试次数,-1时永久尝试 +# Failure retry attempts for origin pulling (-1 for infinite retries). retry_count=3 [http] -#http服务器字符编码集 +# http服务器字符编码集 +# HTTP server character encoding. charSet=utf-8 -#http链接超时时间 +# http链接超时时间 +# HTTP connection timeout in seconds. keepAliveSecond=30 -#http请求体最大字节数,如果post的body太大,则不适合缓存body在内存 +# http请求体最大字节数,如果post的body太大,则不适合缓存body在内存 +# Maximum number of bytes for the HTTP request body. If the POST body is too large, it is not suitable to cache the body in memory. maxReqSize=40960 -#404网页内容,用户可以自定义404网页 +# 404网页内容,用户可以自定义404网页 +# Custom 404 page content. Users can customize the 404 response page here. #notFound=404 Not Found

您访问的资源不存在!


ZLMediaKit-4.0
-#http服务器监听端口 +# http服务器监听端口 +# HTTP server listening port. port=80 -#http文件服务器根目录 -#可以为相对(相对于本可执行程序目录)或绝对路径 +# http文件服务器根目录 +# 可以为相对(相对于本可执行程序目录)或绝对路径 +# HTTP file server root directory (relative or absolute path). rootPath=./www -#http文件服务器读文件缓存大小,单位BYTE,调整该参数可以优化文件io性能 +# http文件服务器读文件缓存大小,单位BYTE,调整该参数可以优化文件io性能 +# HTTP file server read cache size in bytes. Tweak to optimize file I/O performance. sendBufSize=65536 -#https服务器监听端口 +# https服务器监听端口 +# HTTPS server listening port. sslport=443 -#是否显示文件夹菜单,开启后可以浏览文件夹 +# 是否显示文件夹菜单,开启后可以浏览文件夹 +# Whether to enable directory browsing menus. dirMenu=1 -#虚拟目录, 虚拟目录名和文件路径使用","隔开,多个配置路径间用";"隔开 -#例如赋值为 app_a,/path/to/a;app_b,/path/to/b 那么 -#访问 http://127.0.0.1/app_a/file_a 对应的文件路径为 /path/to/a/file_a -#访问 http://127.0.0.1/app_b/file_b 对应的文件路径为 /path/to/b/file_b -#访问其他http路径,对应的文件路径还是在rootPath内 +# 虚拟目录, 虚拟目录名和文件路径使用","隔开,多个配置路径间用";"隔开 +# 例如赋值为 app_a,/path/to/a;app_b,/path/to/b 那么 +# 访问 http://127.0.0.1/app_a/file_a 对应的文件路径为 /path/to/a/file_a +# 访问 http://127.0.0.1/app_b/file_b 对应的文件路径为 /path/to/b/file_b +# 访问其他http路径,对应的文件路径还是在rootPath内 +# Virtual directory mappings. Format: virtual_name,path;virtual_name,path (name and file path separated by ",", multiple mappings separated by ";"). +# For example, set `app_a,/path/to/a;app_b,/path/to/b` then: +# Accessing `http://127.0.0.1/app_a/file_a` maps to `/path/to/a/file_a`. +# Accessing `http://127.0.0.1/app_b/file_b` maps to `/path/to/b/file_b`, while other HTTP paths still map to files under `rootPath`. virtualPath= -#禁止后缀的文件使用mmap缓存,使用“,”隔开 -#例如赋值为 .mp4,.flv -#那么访问后缀为.mp4与.flv 的文件不缓存 +# 禁止后缀的文件使用mmap缓存,使用“,”隔开 +# 例如赋值为 .mp4,.flv +# 那么访问后缀为.mp4与.flv 的文件不缓存 +# Disables `mmap` caching for specific file extensions. Use `,` to separate multiple extensions. +# Example: `.mp4,.flv` means files with these extensions bypass the `mmap` cache. forbidCacheSuffix= -#可以把http代理前真实客户端ip放在http头中:https://github.com/ZLMediaKit/ZLMediaKit/issues/1388 -#切勿暴露此key,否则可能导致伪造客户端ip +# 可以把http代理前真实客户端ip放在http头中:https://github.com/ZLMediaKit/ZLMediaKit/issues/1388 +# 切勿暴露此key,否则可能导致伪造客户端ip +# Header name to trust for extracting the real client IP from an HTTP proxy request header. See: https://github.com/ZLMediaKit/ZLMediaKit/issues/1388 +# Do not expose this key, as it may lead to forged client IPs. forwarded_ip_header= -#默认允许所有跨域请求 +# 默认允许所有跨域请求 +# Whether to allow all cross-origin requests by default (sets generic CORS headers). allow_cross_domains=1 -#允许访问http api和http文件索引的ip地址范围白名单,置空情况下不做限制 +# 允许访问http api和http文件索引的ip地址范围白名单,置空情况下不做限制 +# IP whitelist ranges allowed to access the HTTP API and file indexes. Leave empty to allow any IP without restrictions. allow_ip_range=::1,127.0.0.1,172.16.0.0-172.31.255.255,192.168.0.0-192.168.255.255,10.0.0.0-10.255.255.255 [multicast] -#rtp组播截止组播ip地址 +# rtp组播截止组播ip地址 +# Maximum IP address for the multicast pool. addrMax=239.255.255.255 -#rtp组播起始组播ip地址 +# rtp组播起始组播ip地址 +# Minimum IP address for the multicast pool. addrMin=239.0.0.0 -#组播udp ttl +# 组播udp ttl +# TTL (Time to Live) for multicast UDP packets. udpTTL=64 [record] -#mp4录制或mp4点播的应用名,通过限制应用名,可以防止随意点播 -#点播的文件必须放置在此文件夹下 +# mp4录制或mp4点播的应用名,通过限制应用名,可以防止随意点播 +# 点播的文件必须放置在此文件夹下 +# Application name for MP4 recording/VOD. Restricting this prevents unauthorized VOD access. +# VOD files must be placed within this specific folder. appName=record -#mp4录制写文件缓存,单位BYTE,调整参数可以提高文件io性能 +# mp4录制写文件缓存,单位BYTE,调整参数可以提高文件io性能 +# MP4 recording write cache size in bytes. Tweak to optimize file I/O performance. fileBufSize=65536 -#mp4点播每次流化数据量,单位毫秒, -#减少该值可以让点播数据发送量更平滑,增大该值则更节省cpu资源 +# mp4点播每次流化数据量,单位毫秒, +# 减少该值可以让点播数据发送量更平滑,增大该值则更节省cpu资源 +# Duration (in ms) of MP4 data streamed per VOD transmission block. +# Decreasing this value smooths transmission; increasing it saves CPU resources. sampleMS=500 -#mp4录制完成后是否进行二次关键帧索引写入头部 +# mp4录制完成后是否进行二次关键帧索引写入头部 +# Whether to write a secondary keyframe index into the MP4 header after recording completes (fast start). fastStart=0 -#MP4点播(rtsp/rtmp/http-flv/ws-flv)是否循环播放文件 +# MP4点播(rtsp/rtmp/http-flv/ws-flv)是否循环播放文件 +# Controls whether MP4 VOD playback (rtsp/rtmp/http-flv/ws-flv) loops the file when it reaches the end. fileRepeat=0 -#MP4录制写文件格式是否采用fmp4,启用的话,断电未完成录制的文件也能正常打开 +# MP4录制写文件格式是否采用fmp4,启用的话,断电未完成录制的文件也能正常打开 +# Whether to use the fmp4 format for MP4 recording. Enables normal playback of interrupted recordings (e.g., due to power loss). enableFmp4=0 [rtmp] -#rtmp必须在此时间内完成握手,否则服务器会断开链接,单位秒 +# rtmp必须在此时间内完成握手,否则服务器会断开链接,单位秒 +# RTMP handshake timeout in seconds. The server drops the connection if not completed. handshakeSecond=15 -#rtmp超时时间,如果该时间内未收到客户端的数据, -#或者tcp发送缓存超过这个时间,则会断开连接,单位秒 +# rtmp超时时间,如果该时间内未收到客户端的数据, +# 或者tcp发送缓存超过这个时间,则会断开连接,单位秒 +# RTMP keep-alive timeout in seconds. Connections drop if no data from the client is received, +# or if the TCP send buffer stall exceeds this duration. keepAliveSecond=15 -#rtmp服务器监听端口 +# rtmp服务器监听端口 +# RTMP server listening port. port=1935 -#rtmps服务器监听地址 +# rtmps服务器监听地址 +# RTMPS server listening port. sslport=0 # rtmp是否直接代理模式 +# Whether to enable direct proxy mode for RTMP. directProxy=1 -#h265 rtmp打包采用增强型rtmp标准还是国内拓展标准 -enhanced=0 +# h265/opus/vp8/vp9/av1 rtmp打包采用增强型rtmp标准还是国内拓展标准 +# Whether RTMP packaging for H265/Opus/VP8/VP9/AV1 uses the Enhanced RTMP standard (1) or the domestic extended standard (0). +enhanced=1 [rtp] -#音频mtu大小,该参数限制rtp最大字节数,推荐不要超过1400 -#加大该值会明显增加直播延时 +# 音频mtu大小,该参数限制rtp最大字节数,推荐不要超过1400 +# 加大该值会明显增加直播延时 +# Audio MTU size (restricts max RTP payload in bytes). We recommend keeping this <= 1400. +# Increasing this value significantly increases live streaming latency. audioMtuSize=600 -#视频mtu大小,该参数限制rtp最大字节数,推荐不要超过1400 +# 视频mtu大小,该参数限制rtp最大字节数,推荐不要超过1400 +# Video MTU size (restricts max RTP payload in bytes). We recommend keeping this <= 1400. videoMtuSize=1400 -#rtp包最大长度限制,单位KB,主要用于识别TCP上下文破坏时,获取到错误的rtp +# rtp包最大长度限制,单位KB,主要用于识别TCP上下文破坏时,获取到错误的rtp +# Max RTP packet length in KB. Mainly used to identify receiving wrong RTP packets when TCP stream contexts are corrupted. rtpMaxSize=10 # rtp 打包时,低延迟开关,默认关闭(为0),h264存在一帧多个slice(NAL)的情况,在这种情况下,如果开启可能会导致画面花屏 +# Low-latency mode for RTP packaging (disabled by default). Enabling this for H.264 video with multiple slices per frame may cause visual artifacts (glitches). lowLatency=0 # H264 rtp打包模式是否采用stap-a模式(为了在老版本浏览器上兼容webrtc)还是采用Single NAL unit packet per H.264 模式 # 有些老的rtsp设备不支持stap-a rtp,设置此配置为0可提高兼容性 +# Whether H.264 RTP packaging uses the `stap-a` mode (for older WebRTC browser compatibility) or the `Single NAL unit packet per H.264` mode. +# Set this to 0 to improve compatibility with legacy RTSP devices that do not support `stap-a`. h264_stap_a=1 [rtp_proxy] -#导出调试数据(包括rtp/ps/h264)至该目录,置空则关闭数据导出 +# 导出调试数据(包括rtp/ps/h264)至该目录,置空则关闭数据导出 +# Directory for exporting debugging data (rtp/ps/h264). Leave empty to disable. dumpDir= -#udp和tcp代理服务器,支持rtp(必须是ts或ps类型)代理 +# udp和tcp代理服务器,支持rtp(必须是ts或ps类型)代理 +# UDP/TCP proxy server listening port. Supports RTP proxying (must be TS or PS). port=10000 -#rtp超时时间,单位秒 +# rtp超时时间,单位秒 +# RTP timeout in seconds. timeoutSec=15 -#随机端口范围,最少确保36个端口 -#该范围同时限制rtsp服务器udp端口范围 +# 随机端口范围,最少确保36个端口 +# 该范围同时限制rtsp服务器udp端口范围 +# Random port range (ensure at least 36 ports). +# This also restricts the UDP port range for the RTSP server. port_range=30000-35000 -#rtp h264 负载的pt +# rtp h264 负载的pt +# RTP payload type (PT) for H.264. h264_pt=98 -#rtp h265 负载的pt +# rtp h265 负载的pt +# RTP payload type (PT) for H.265. h265_pt=99 -#rtp ps 负载的pt +# rtp ps 负载的pt +# RTP payload type (PT) for PS. ps_pt=96 -#rtp opus 负载的pt +# rtp opus 负载的pt +# RTP payload type (PT) for Opus. opus_pt=100 -#RtpSender相关功能是否提前开启gop缓存优化级联秒开体验,默认开启 -#如果不调用startSendRtp相关接口,可以置0节省内存 +# startSendRtp、startRecord相关功能是否提前开启gop缓存优化级联秒开体验,默认开启, 并缓存1个GOP +# 如果不调用startSendRtp、startRecord后相关接口,可以置0节省内存;如果缓存多个gop,可以加大该参数 +# Whether to pre-enable GOP caching for `startSendRtp` and `startRecord` to optimize instant playback for cascaded streams. Enabled by default, caching 1 GOP. +# If these functions are unused, set to 0 to save memory; to cache multiple GOPs, increase this value. gop_cache=1 -#国标发送g711 rtp 打包时,每个包的语音时长是多少,默认是100 ms,范围为20~180ms (gb28181-2016,c.2.4规定), -#最好为20 的倍数,程序自动向20的倍数取整 -rtp_g711_dur_ms = 100 -#udp接收数据socket buffer大小配置 -#4*1024*1024=4196304 +# 国标发送g711 rtp 打包时,每个包的语音时长是多少,默认是100 ms,范围为20~180ms (gb28181-2016,c.2.4规定), +# 最好为20 的倍数,程序自动向20的倍数取整 +# Audio duration (in ms) per packet when packaging G.711 RTP for GB standards. Defaults to 100 ms (range: 20~180ms per gb28181-2016, c.2.4). +# A multiple of 20 is recommended; the program auto-rounds to the nearest multiple. +rtp_g711_dur_ms=100 +# udp接收数据socket buffer大小配置 +# 4*1024*1024=4196304 +# Socket buffer size for receiving UDP data. udp_recv_socket_buffer=4194304 +# ps/ts解析后是否等待下一帧以判断本帧是否完整,开启后提高兼容性,但是可能增加延时 +# Whether to wait for the next frame after parsing PS/TS to verify frame completeness. Improves compatibility but may increase latency. +merge_frame=1 [rtc] -#rtc播放推流、播放超时时间 +# webrtc 信令服务器端口 +# WebRTC signaling server port. +signalingPort=3000 +signalingSslPort=3001 +# STUN/TURN服务器端口 +# STUN/TURN server port. +icePort=3478 +iceTcpPort=3478 +# STUN/TURN端口是否使能TURN服务 +# Whether to enable TURN services on the STUN/TURN ports. +enableTurn=1 +# ICE传输策略:0=不限制(默认),1=仅支持Relay转发,2=仅支持P2P直连 +# ICE transport policy: 0 (No restrictions, default), 1 (Relay forwarding only), 2 (P2P direct connection only). +iceTransportPolicy=0 +# STUN/TURN 服务Ice密码 +# ICE credentials for STUN/TURN services. +iceUfrag=ZLMediaKit +icePwd=ZLMediaKit +# webrtc datachannel是否回显数据,测试用 +# Whether WebRTC Datachannel echoes received data (used for testing). +datachannel_echo=1 + +max_stun_retry=7 +# TURN服务分配端口池 +# Port range allocated for TURN services. +port_range=49152-65535 +# rtc播放推流、播放超时时间 +# Timeout in seconds for RTC stream publishing and playback. timeoutSec=15 -#本机对rtc客户端的可见ip,作为服务器时一般为公网ip,可有多个,用','分开,当置空时,会自动获取网卡ip -#同时支持环境变量,以$开头,如"$EXTERN_IP"; 请参考:https://github.com/ZLMediaKit/ZLMediaKit/pull/1786 +# 本机对rtc客户端的可见ip,作为服务器时一般为公网ip,可有多个,用','分开,当置空时,会自动获取网卡ip +# 同时支持环境变量,以$开头,如"$EXTERN_IP"; 请参考:https://github.com/ZLMediaKit/ZLMediaKit/pull/1786 +# IP address(es) visible to RTC clients (typically public IPs). Separate multiple IPs with commas (','). +# Leave empty to auto-acquire network card IPs. Also supports env vars starting with `$`, e.g., `"$EXTERN_IP"`; please refer to: https://github.com/ZLMediaKit/ZLMediaKit/pull/1786 externIP= -#rtc udp服务器监听端口号,所有rtc客户端将通过该端口传输stun/dtls/srtp/srtcp数据, -#该端口是多线程的,同时支持客户端网络切换导致的连接迁移 -#需要注意的是,如果服务器在nat内,需要做端口映射时,必须确保外网映射端口跟该端口一致 +# 当指定了interfaces,ICE服务器会使用指定网卡bind socket +# 以解决公网IP使用弹性公网IP配置实现(部署机器无法bind该公网ip的问题) +# 支持环境变量,以$开头,如"$PRIVATE_IP" +# If specified, the ICE server binds the socket to this specific network card. +# Solves binding issues on machines with Elastic Public IPs that cannot directly bind the public IP. +# Supports environment variables starting with `$`, e.g., `"$PRIVATE_IP"`. +interfaces= +# rtc udp服务器监听端口号,所有rtc客户端将通过该端口传输stun/dtls/srtp/srtcp数据, +# 该端口是多线程的,同时支持客户端网络切换导致的连接迁移 +# 需要注意的是,如果服务器在nat内,需要做端口映射时,必须确保外网映射端口跟该端口一致 +# RTC UDP server listening port. Handles STUN/DTLS/SRTP/SRTCP data for all RTC clients. +# Multi-threaded and supports connection migration during client network switching. +# Note: For deployment behind a NAT, the external mapped port MUST match this port exactly. port=8000 -#rtc tcp服务器监听端口号,在udp 不通的情况下,会使用tcp传输数据 -#该端口是多线程的,同时支持客户端网络切换导致的连接迁移 -#需要注意的是,如果服务器在nat内,需要做端口映射时,必须确保外网映射端口跟该端口一致 -tcpPort = 8000 -#设置remb比特率,非0时关闭twcc并开启remb。该设置在rtc推流时有效,可以控制推流画质 -#目前已经实现twcc自动调整码率,关闭remb根据真实网络状况调整码率 +# rtc tcp服务器监听端口号,在udp 不通的情况下,会使用tcp传输数据 +# 该端口是多线程的,同时支持客户端网络切换导致的连接迁移 +# 需要注意的是,如果服务器在nat内,需要做端口映射时,必须确保外网映射端口跟该端口一致 +# RTC TCP server listening port. Used as a fallback if UDP is unreachable. +# Multi-threaded and supports connection migration during client network switching. +# Note: For deployment behind a NAT, the external mapped port MUST match this port exactly. +tcpPort=8000 +# 设置remb比特率,非0时关闭twcc并开启remb。该设置在rtc推流时有效,可以控制推流画质 +# 目前已经实现twcc自动调整码率,关闭remb根据真实网络状况调整码率 +# REMB bitrate threshold. Non-zero values disable TWCC and enable REMB (effective for RTC publishing to control picture quality). +# ZLMediaKit natively supports automatic TWCC bitrate adjustment; disabling REMB allows rates to adjust naturally based on actual network conditions. rembBitRate=0 -#rtc支持的音频codec类型,在前面的优先级更高 -#以下范例为所有支持的音频codec +# rtc支持的音频codec类型,在前面的优先级更高 +# 以下范例为所有支持的音频codec +# Supported RTC audio codecs (listed in descending priority). preferredCodecA=PCMA,PCMU,opus,mpeg4-generic -#rtc支持的视频codec类型,在前面的优先级更高 -#以下范例为所有支持的视频codec +# rtc支持的视频codec类型,在前面的优先级更高 +# 以下范例为所有支持的视频codec +# Supported RTC video codecs (listed in descending priority). preferredCodecV=H264,H265,AV1,VP9,VP8 # 是否开启RTC协议的G711转码,开启后 # 能将传给rtc的g711音频转成opus # 将由rtc流入g711音频转成aac,并转给其他协议流 transcodeG711=0 -#webrtc比特率设置 +# webrtc比特率设置 +# WebRTC bitrate settings. start_bitrate=0 max_bitrate=0 min_bitrate=0 -#nack接收端, rtp发送端,zlm发送rtc流 -#rtp重发缓存列队最大长度,单位毫秒 +# nack接收端, rtp发送端,zlm发送rtc流 +# rtp重发缓存列队最大长度,单位毫秒 +# NACK receiver / RTP sender queue (ZLM sending RTC streams). +# Maximum length of the RTP retransmission cache queue in ms. maxRtpCacheMS=5000 -#rtp重发缓存列队最大长度,单位个数 +# rtp重发缓存列队最大长度,单位个数 +# Maximum length of the RTP retransmission cache queue in packet count. maxRtpCacheSize=2048 -#nack发送端,rtp接收端,zlm接收rtc推流 -#最大保留的rtp丢包状态个数 +# nack发送端,rtp接收端,zlm接收rtc推流 +# 最大保留的rtp丢包状态个数 +# NACK sender / RTP receiver queue (ZLM receiving RTC streams). +# Maximum number of retained RTP packet-loss states. nackMaxSize=2048 -#rtp丢包状态最长保留时间 +# rtp丢包状态最长保留时间 +# Maximum retention time for RTP packet-loss states in ms. nackMaxMS=3000 -#nack最多请求重传次数 +# nack最多请求重传次数 +# Maximum number of NACK retransmission requests. nackMaxCount=15 -#nack重传频率,rtt的倍数 +# nack重传频率,rtt的倍数 +# NACK retransmission frequency (multiple of RTT). nackIntervalRatio=1.0 -#nack包中rtp个数,减小此值可以让nack包响应更灵敏 +# 视频nack包中rtp个数,减小此值可以让nack包响应更灵敏 +# Number of RTP packets in a video NACK packet. Lower values make NACK responses more sensitive. nackRtpSize=8 +# 音频nack包中rtp个数,减小此值可以让nack包响应更灵敏 +# Number of RTP packets in an audio NACK packet. Lower values make NACK responses more sensitive. +nackAudioRtpSize=4 +# 是否尝试过滤 b帧 +# Whether to attempt filtering out B-frames. +bfilter=0 +# 是否优先采用webrtc over tcp模式 +# Whether to prioritize WebRTC over TCP mode. +preferred_tcp=0 [srt] -#srt播放推流、播放超时时间,单位秒 +# srt播放推流、播放超时时间,单位秒 +# Timeout in seconds for SRT stream publishing and playback. timeoutSec=5 -#srt udp服务器监听端口号,所有srt客户端将通过该端口传输srt数据, -#该端口是多线程的,同时支持客户端网络切换导致的连接迁移 +# srt udp服务器监听端口号,所有srt客户端将通过该端口传输srt数据, +# 该端口是多线程的,同时支持客户端网络切换导致的连接迁移 +# SRT UDP server listening port. Handles SRT data for all clients. +# Multi-threaded and supports connection migration during client network switching. port=9000 -#srt 协议中延迟缓存的估算参数,在握手阶段估算rtt ,然后latencyMul*rtt 为最大缓存时长,此参数越大,表示等待重传的时长就越大 +# srt 协议中延迟缓存的估算参数,在握手阶段估算rtt ,然后latencyMul*rtt 为最大缓存时长,此参数越大,表示等待重传的时长就越大 +# SRT protocol delay buffer estimation parameter. Handshake estimated `RTT * latencyMul` sets the maximum buffer duration. Larger values increase wait times for retransmissions. latencyMul=4 -#包缓存的大小 +# 包缓存的大小 +# Packet buffer size. pktBufSize=8192 -#srt udp服务器的密码,为空表示不加密 +# srt udp服务器的密码,为空表示不加密 +# SRT UDP server password (leave empty to disable encryption). passPhrase= - [rtsp] -#rtsp专有鉴权方式是采用base64还是md5方式 +# rtsp专有鉴权方式是采用base64还是md5方式 +# Whether RTSP dedicated authentication uses base64 or md5. authBasic=0 -#rtsp拉流、推流代理是否是直接代理模式 -#直接代理后支持任意编码格式,但是会导致GOP缓存无法定位到I帧,可能会导致开播花屏 -#并且如果是tcp方式拉流,如果rtp大于mtu会导致无法使用udp方式代理 -#假定您的拉流源地址不是264或265或AAC,那么你可以使用直接代理的方式来支持rtsp代理 -#如果你是rtsp推拉流,但是webrtc播放,也建议关闭直接代理模式, -#因为直接代理时,rtp中可能没有sps pps,会导致webrtc无法播放; 另外webrtc也不支持Single NAL Unit Packets类型rtp -#默认开启rtsp直接代理,rtmp由于没有这些问题,是强制开启直接代理的 +# rtsp拉流、推流代理是否是直接代理模式 +# 直接代理后支持任意编码格式,但是会导致GOP缓存无法定位到I帧,可能会导致开播花屏 +# 并且如果是tcp方式拉流,如果rtp大于mtu会导致无法使用udp方式代理 +# 假定您的拉流源地址不是264或265或AAC,那么你可以使用直接代理的方式来支持rtsp代理 +# 如果你是rtsp推拉流,但是webrtc播放,也建议关闭直接代理模式, +# 因为直接代理时,rtp中可能没有sps pps,会导致webrtc无法播放; 另外webrtc也不支持Single NAL Unit Packets类型rtp +# 默认开启rtsp直接代理,rtmp由于没有这些问题,是强制开启直接代理的 +# Whether to enable direct proxy mode for RTSP pulling/publishing. +# Direct proxying supports any codec but bypasses GOP cache I-frame detection, potentially causing initial visual artifacts. +# Furthermore, if pulling via TCP, an RTP payload exceeding the MTU will make UDP proxying unusable. +# Assuming your pull source format is not H264, H265, or AAC, you can use direct proxy mode to support RTSP proxying. +# If you are pulling/pushing via RTSP but playing via WebRTC, it is also recommended to disable direct proxy mode; +# this is because direct proxies may drop SPS/PPS (preventing WebRTC playback), and WebRTC does not support `Single NAL Unit Packets` RTP. +# RTSP direct proxy is enabled by default. RTMP natively enforces direct proxying because it lacks these issues. directProxy=1 -#rtsp必须在此时间内完成握手,否则服务器会断开链接,单位秒 +# rtsp必须在此时间内完成握手,否则服务器会断开链接,单位秒 +# RTSP handshake timeout in seconds. The server drops the connection if not completed. handshakeSecond=15 -#rtsp超时时间,如果该时间内未收到客户端的数据, -#或者tcp发送缓存超过这个时间,则会断开连接,单位秒 +# rtsp超时时间,如果该时间内未收到客户端的数据, +# 或者tcp发送缓存超过这个时间,则会断开连接,单位秒 +# RTSP keep-alive timeout in seconds. Connections drop if no data is received or if the TCP send buffer stalls for this duration. keepAliveSecond=15 -#rtsp服务器监听地址 +# rtsp服务器监听地址 +# RTSP server listening port. port=554 -#rtsps服务器监听地址 +# rtsps服务器监听地址 +# RTSPS server listening port. sslport=0 -#rtsp 转发是否使用低延迟模式,当开启时,不会缓存rtp包,来提高并发,可以降低一帧的延迟 +# rtsp 转发是否使用低延迟模式,当开启时,不会缓存rtp包,来提高并发,可以降低一帧的延迟 +# Whether RTSP forwarding uses low-latency mode. Skips RTP packet caching to improve concurrency and reduce latency by one frame. lowLatency=0 -#强制协商rtp传输方式 (0:TCP,1:UDP,2:MULTICAST,-1:不限制) -#当客户端发起RTSP SETUP的时候如果传输类型和此配置不一致则返回461 Unsupported transport -#迫使客户端重新SETUP并切换到对应协议。目前支持FFMPEG和VLC +# 强制协商rtp传输方式 (0:TCP,1:UDP,2:MULTICAST,-1:不限制) +# 当客户端发起RTSP SETUP的时候如果传输类型和此配置不一致则返回461 Unsupported transport +# 迫使客户端重新SETUP并切换到对应协议。目前支持FFMPEG和VLC +# Force RTP transport negotiation type (0: TCP, 1: UDP, 2: MULTICAST, -1: No limits). +# When the client initiates RTSP SETUP, if the transport type conflicts with this configuration, it returns `461 Unsupported transport`. +# This forces the client to re-SETUP and switch to the corresponding protocol. Currently supports FFmpeg and VLC. rtpTransportType=-1 + [shell] -#调试telnet服务器接受最大bufffer大小 +# 调试telnet服务器接受最大buffer大小 +# Maximum buffer size accepted by the debugging Telnet server. maxReqSize=1024 -#调试telnet服务器监听端口 +# 调试telnet服务器监听端口 +# Debugging Telnet server listening port. port=0 +# onvif搜索用 +# Used for ONVIF search. +[onvif] +port=3702 diff --git a/conf/readme_en.md b/conf/readme_en.md new file mode 100644 index 00000000..81f56f6e --- /dev/null +++ b/conf/readme_en.md @@ -0,0 +1,31 @@ +## Key parameters that affect performance in the configuration file + +### 1. Protocol enable flags (e.g., protocol.enable_hls, protocol.enable_rtsp) + +Controls the protocol conversion flags. Disabling unnecessary protocols will save CPU and memory resources. + +### 2. On-demand protocol flags (e.g., protocol.hls_demand, protocol.rtsp_demand) + +Controls on-demand protocol generation. When both this and the specific protocol are enabled, it saves CPU and memory when there are no active viewers. However, the first viewer will lose the instant playback capability, impacting the initial experience. + +### 3. protocol.paced_sender_ms + +The interval for the smooth sending timer. This helps address playback stuttering caused by irregular data transmission from the source. When enabled, the timer uses data timestamps to pace the transmission, improving the viewing experience. +However, this increases CPU and memory consumption. A shorter timer interval results in higher CPU usage but better smoothness. The recommended interval is between 30 and 100 milliseconds. For optimal results, use this feature in conjunction with setting `protocol.modify_stamp` to 2 (which suppresses timestamp jumps). + +### 4. general.mergeWriteMS + +Enables write coalescing, which reduces the number of system calls and the frequency of data sharing between threads during transmission. This significantly boosts forwarding performance but comes at the cost of increased playback latency and reduced transmission smoothness. + +### 5. rtp_proxy.gop_cache + +Enables the GOP (Group of Pictures) caching feature for the `startSendRtp` cascaded interface, designed to allow instant playback for cascading setups (e.g., GB28181). Note that this setting does not affect the instant playback capability of ZLMediaKit's external live streaming services. +Enabling this option increases memory usage but has a minimal impact on the CPU. We recommend disabling it if you don't use the `startSendRtp` interface. + +### 6. hls.fileBufSize + +Tuning this parameter can improve the disk I/O performance when writing HLS streams. + +### 7. record.fileBufSize + +Tuning this parameter can improve the disk I/O performance when recording MP4 files. diff --git a/dockerfile b/dockerfile index a3e877e6..027f808d 100644 --- a/dockerfile +++ b/dockerfile @@ -1,5 +1,5 @@ -FROM ubuntu:20.04 AS build -ARG MODEL +FROM ubuntu:24.04 AS build +ARG MODEL=Release #shell,rtmp,rtsp,rtsps,http,https,rtp EXPOSE 1935/tcp EXPOSE 554/tcp @@ -27,6 +27,7 @@ RUN apt-get update && \ libssl-dev \ gcc \ g++ \ + python3-dev \ gdb && \ apt-get autoremove -y && \ apt-get clean -y && \ @@ -41,17 +42,17 @@ WORKDIR /opt/media/ZLMediaKit/3rdpart RUN wget https://github.com/cisco/libsrtp/archive/v2.3.0.tar.gz -O libsrtp-2.3.0.tar.gz && \ tar xfv libsrtp-2.3.0.tar.gz && \ mv libsrtp-2.3.0 libsrtp && \ - cd libsrtp && ./configure --enable-openssl && make -j $(nproc) && make install + cd libsrtp && CFLAGS="-fcommon" ./configure --enable-openssl && make -j $(nproc) && make install #RUN git submodule update --init --recursive && \ RUN mkdir -p build release/linux/${MODEL}/ WORKDIR /opt/media/ZLMediaKit/build -RUN cmake -DCMAKE_BUILD_TYPE=${MODEL} -DENABLE_WEBRTC=true -DENABLE_FFMPEG=true -DENABLE_TESTS=false -DENABLE_API=false .. && \ +RUN cmake -DENABLE_PYTHON=true -DCMAKE_BUILD_TYPE=${MODEL} -DENABLE_WEBRTC=true -DENABLE_FFMPEG=true -DENABLE_TESTS=false -DENABLE_API=false .. && \ make -j $(nproc) -FROM ubuntu:20.04 -ARG MODEL +FROM ubuntu:24.04 +ARG MODEL=Release # ADD sources.list /etc/apt/sources.list @@ -67,6 +68,10 @@ RUN apt-get update && \ ffmpeg \ gcc \ g++ \ + python3 \ + python3-dev \ + python3-venv \ + python3-pip \ gdb && \ apt-get autoremove -y && \ apt-get clean -y && \ diff --git a/ext-codec/AAC.cpp b/ext-codec/AAC.cpp index 972aea31..75a38d43 100644 --- a/ext-codec/AAC.cpp +++ b/ext-codec/AAC.cpp @@ -413,6 +413,12 @@ Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { // If aac config information cannot be obtained from sdp, then it cannot be obtained from rtp either, so ignore this Track return nullptr; } + while (aac_cfg_str.size() < 4) { + aac_cfg_str = '0' + aac_cfg_str; + } + if (aac_cfg_str.size() > 4) { + aac_cfg_str = aac_cfg_str.substr(0, 4); + } string aac_cfg; for (size_t i = 0; i < aac_cfg_str.size() / 2; ++i) { unsigned int cfg; diff --git a/ext-codec/AV1.cpp b/ext-codec/AV1.cpp new file mode 100644 index 00000000..ada4241a --- /dev/null +++ b/ext-codec/AV1.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "AV1.h" +#include "AV1Rtp.h" +#include "VpxRtmp.h" +#include "Extension/Factory.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +bool AV1Track::inputFrame(const Frame::Ptr &frame) { + char *dataPtr = frame->data() + frame->prefixSize(); + if (0 == aom_av1_codec_configuration_record_init(&_context, dataPtr, frame->size() - frame->prefixSize())) { + _width = _context.width; + _height = _context.height; + //InfoL << _width << "x" << _height; + } + return VideoTrackImp::inputFrame(frame); +} + +Track::Ptr AV1Track::clone() const { + return std::make_shared(*this); +} + +Buffer::Ptr AV1Track::getExtraData() const { + if (_context.bytes <= 0) + return nullptr; + auto ret = BufferRaw::create(4 + _context.bytes); + ret->setSize(aom_av1_codec_configuration_record_save(&_context, (uint8_t *)ret->data(), ret->getCapacity())); + return ret; +} + +void AV1Track::setExtraData(const uint8_t *data, size_t size) { + if (aom_av1_codec_configuration_record_load(data, size, &_context) > 0) { + _width = _context.width; + _height = _context.height; + } +} + +namespace { + +CodecId getCodec() { + return CodecAV1; +} + +Track::Ptr getTrackByCodecId(int sample_rate, int channels, int sample_bit) { + return std::make_shared(); +} + +Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpDecoderByCodecId() { + return std::make_shared(); +} + +RtmpCodec::Ptr getRtmpEncoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +RtmpCodec::Ptr getRtmpDecoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +Frame::Ptr getFrameFromPtr(const char *data, size_t bytes, uint64_t dts, uint64_t pts) { + return std::make_shared((char *)data, bytes, dts, pts, 0); +} + +} // namespace + +CodecPlugin av1_plugin = { getCodec, + getTrackByCodecId, + getTrackBySdp, + getRtpEncoderByCodecId, + getRtpDecoderByCodecId, + getRtmpEncoderByTrack, + getRtmpDecoderByTrack, + getFrameFromPtr }; + +} // namespace mediakit \ No newline at end of file diff --git a/ext-codec/AV1.h b/ext-codec/AV1.h new file mode 100644 index 00000000..dcf9d1b6 --- /dev/null +++ b/ext-codec/AV1.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_AV1_H +#define ZLMEDIAKIT_AV1_H + +#include "Extension/Frame.h" +#include "Extension/Track.h" +#include "aom-av1.h" +namespace mediakit { + +template +class AV1FrameHelper : public Parent { +public: + friend class FrameImp; + //friend class toolkit::ResourcePool_l; + using Ptr = std::shared_ptr; + + template + AV1FrameHelper(ARGS &&...args) + : Parent(std::forward(args)...) { + this->_codec_id = CodecAV1; + } + + bool keyFrame() const override { + auto ptr = (uint8_t *) this->data() + this->prefixSize(); + return (*ptr & 0x78) >> 3 == 1; + } + bool configFrame() const override { return false; } + bool dropAble() const override { return false; } + bool decodeAble() const override { return true; } +}; + +/// Av1 帧类 +using AV1Frame = AV1FrameHelper; +using AV1FrameNoCacheAble = AV1FrameHelper; + +/** + * AV1视频通道 + */ +class AV1Track : public VideoTrackImp { +public: + using Ptr = std::shared_ptr; + + AV1Track() : VideoTrackImp(CodecAV1) {} + + Track::Ptr clone() const override; + + bool inputFrame(const Frame::Ptr &frame) override; + toolkit::Buffer::Ptr getExtraData() const override; + void setExtraData(const uint8_t *data, size_t size) override; +protected: + aom_av1_t _context {}; +}; + +} // namespace mediakit + +#endif \ No newline at end of file diff --git a/ext-codec/AV1Rtp.cpp b/ext-codec/AV1Rtp.cpp new file mode 100644 index 00000000..7432ee14 --- /dev/null +++ b/ext-codec/AV1Rtp.cpp @@ -0,0 +1,582 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ +#include "AV1.h" +#include "AV1Rtp.h" +#include +#include +#include +#include +#include + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +// AV1 OBU类型定义 +static constexpr int kObuTypeSequenceHeader = 1; +static constexpr int kObuTypeTemporalDelimiter = 2; +static constexpr int kObuTypeTileList = 8; +static constexpr int kObuTypePadding = 15; + +// RTP聚合头中的位定义 +static constexpr uint8_t kObuSizePresentBit = 0b00000010; +static constexpr int kAggregationHeaderSize = 1; +static constexpr int kMaxNumObusToOmitSize = 3; + +// LEB128编码/解码辅助函数 +static size_t writeLeb128(uint64_t value, uint8_t* buffer) { + size_t size = 0; + do { + uint8_t byte = value & 0x7F; + value >>= 7; + if (value != 0) { + byte |= 0x80; + } + buffer[size++] = byte; + } while (value != 0); + return size; +} + +static size_t leb128Size(uint64_t value) { + size_t size = 0; + do { + value >>= 7; + ++size; + } while (value != 0); + return size; +} + +static bool readLeb128(const uint8_t*& data, size_t& remaining, uint64_t& value) { + value = 0; + size_t shift = 0; + + while (remaining > 0 && shift < 56) { + uint8_t byte = *data++; + remaining--; + + value |= (uint64_t(byte & 0x7F) << shift); + shift += 7; + + if ((byte & 0x80) == 0) { + return true; + } + } + + // 兼容性处理:如果到达数据末尾但最后一个字节的MSB仍为1, + // 假设这是leb128编码的结尾 + if (remaining == 0 && shift > 0) { + WarnL << "Tolerating non-standard LEB128 encoding (missing termination bit)"; + return true; + } + + return false; +} + +// OBU辅助函数 +static bool obuHasExtension(uint8_t obu_header) { + return obu_header & 0b00000100; +} + +static bool obuHasSize(uint8_t obu_header) { + return obu_header & kObuSizePresentBit; +} + +static int obuType(uint8_t obu_header) { + return (obu_header & 0b01111000) >> 3; +} + +static int maxFragmentSize(int remaining_bytes) { + if (remaining_bytes <= 1) { + return 0; + } + for (int i = 1; ; ++i) { + if (remaining_bytes < (1 << (7 * i)) + i) { + return remaining_bytes - i; + } + } +} + +////////////////////////////////////////////////////////////////////////// +// AV1RtpEncoder 实现 +////////////////////////////////////////////////////////////////////////// + +AV1RtpEncoder::AV1RtpEncoder() { +} + +std::vector AV1RtpEncoder::parseObus(const uint8_t* data, size_t size) { + std::vector result; + const uint8_t* ptr = data; + size_t remaining = size; + + while (remaining > 0) { + if (remaining < 1) { + WarnL << "Malformed AV1 input: expected OBU header"; + return {}; + } + + ObuInfo obu{}; + obu.header = *ptr++; + remaining--; + obu.has_extension = obuHasExtension(obu.header); + obu.has_size_field = obuHasSize(obu.header); + + if (obu.has_extension) { + if (remaining < 1) { + WarnL << "Malformed AV1 input: expected extension header"; + return {}; + } + obu.extension_header = *ptr++; + remaining--; + } + + uint64_t payload_size = 0; + if (obu.has_size_field) { + if (!readLeb128(ptr, remaining, payload_size)) { + WarnL << "Malformed AV1 input: failed to read OBU size"; + return {}; + } + if (payload_size > remaining) { + WarnL << "Malformed AV1 input: OBU size exceeds remaining data"; + return {}; + } + } else { + payload_size = remaining; + } + + obu.payload_data = ptr; + obu.payload_size = payload_size; + ptr += payload_size; + remaining -= payload_size; + + int type = obuType(obu.header); + if (type != kObuTypeTemporalDelimiter && + type != kObuTypeTileList && + type != kObuTypePadding) { + result.push_back(obu); + } + } + + return result; +} + +uint8_t AV1RtpEncoder::makeAggregationHeader(bool first_obu_is_fragment, + bool last_obu_is_fragment, + int num_obu_elements, + bool starts_new_coded_video_sequence) { + uint8_t header = 0; + + // Z bit: first OBU element is continuation of previous OBU + if (first_obu_is_fragment) { + header |= 0x80; + } + + // Y bit: last OBU element will be continued in next packet + if (last_obu_is_fragment) { + header |= 0x40; + } + + // W field: number of OBU elements (when <= 3) + if (num_obu_elements <= kMaxNumObusToOmitSize) { + header |= (num_obu_elements << 4); + } + + // N bit: beginning of new coded video sequence + if (starts_new_coded_video_sequence) { + header |= 0x08; + } + + return header; +} + +void AV1RtpEncoder::outputRtp(const uint8_t* data, size_t len, bool mark, + uint64_t stamp, uint8_t aggregation_header) { + auto rtp = getRtpInfo().makeRtp(TrackVideo, nullptr, len + kAggregationHeaderSize, mark, stamp); + auto payload = rtp->data() + RtpPacket::kRtpTcpHeaderSize + RtpPacket::kRtpHeaderSize; + + // 写入聚合头 + payload[0] = aggregation_header; + + // 复制数据 + if (len > 0) { + memcpy(payload + kAggregationHeaderSize, data, len); + } + + RtpCodec::inputRtp(std::move(rtp), false); +} + +bool AV1RtpEncoder::inputFrame(const Frame::Ptr &frame) { + auto ptr = frame->data() + frame->prefixSize(); + auto size = frame->size() - frame->prefixSize(); + + if (size == 0) { + return false; + } + + // 解析OBU + auto obus = parseObus((const uint8_t*)ptr, size); + if (obus.empty()) { + return false; + } + + // 检查是否包含序列头(关键帧标志) + bool has_sequence_header = false; + for (const auto& obu : obus) { + int type = obuType(obu.header); + if (type == kObuTypeSequenceHeader) { + has_sequence_header = true; + _got_key_frame = true; + break; + } + } + + // 如果还没有收到过关键帧,且当前帧不是关键帧,则丢弃 + if (!_got_key_frame && !has_sequence_header) { + DebugL << "Dropping AV1 frame before first keyframe"; + return false; + } + + size_t max_payload_size = getRtpInfo().getMaxSize() - kAggregationHeaderSize; + if (max_payload_size == 0) { + WarnL << "Invalid RTP max payload size for AV1"; + return false; + } + + for (size_t i = 0; i < obus.size(); ++i) { + const auto& obu = obus[i]; + bool is_first_obu = (i == 0); + bool is_last_obu = (i == obus.size() - 1); + if (!sendObu(obu, is_first_obu, is_last_obu, + has_sequence_header && is_first_obu, frame->pts(), max_payload_size)) { + return false; + } + } + + return true; +} + +bool AV1RtpEncoder::sendObu(const ObuInfo& obu, + bool is_first_obu, + bool is_last_obu, + bool starts_new_sequence, + uint64_t stamp, + size_t max_payload_size) { + std::vector obu_bytes; + obu_bytes.reserve(1 + (obu.has_extension ? 1 : 0) + obu.payload_size); + obu_bytes.push_back(obu.header & ~kObuSizePresentBit); + if (obu.has_extension) { + obu_bytes.push_back(obu.extension_header); + } + if (obu.payload_size > 0) { + obu_bytes.insert(obu_bytes.end(), obu.payload_data, obu.payload_data + obu.payload_size); + } + + size_t offset = 0; + bool first_fragment = true; + while (offset < obu_bytes.size()) { + size_t fragment_size = std::min(max_payload_size, obu_bytes.size() - offset); + bool last_fragment = (offset + fragment_size) == obu_bytes.size(); + uint8_t agg_header = makeAggregationHeader( + !first_fragment, + !last_fragment, + 1, + first_fragment && starts_new_sequence + ); + + bool mark = last_fragment && is_last_obu; + outputRtp(obu_bytes.data() + offset, fragment_size, mark, stamp, agg_header); + + offset += fragment_size; + first_fragment = false; + } + + return true; +} + +////////////////////////////////////////////////////////////////////////// +// AV1RtpDecoder 实现 +////////////////////////////////////////////////////////////////////////// + +AV1RtpDecoder::AV1RtpDecoder() { + obtainFrame(); +} + +void AV1RtpDecoder::obtainFrame() { + _frame = FrameImp::create(); +} + +AV1RtpDecoder::AggregationHeader AV1RtpDecoder::parseAggregationHeader(uint8_t header) { + AggregationHeader agg; + agg.first_obu_is_fragment = (header & 0x80) != 0; + agg.last_obu_is_fragment = (header & 0x40) != 0; + agg.num_obu_elements = (header & 0x30) >> 4; + agg.starts_new_coded_video_sequence = (header & 0x08) != 0; + return agg; +} + +bool AV1RtpDecoder::inputRtp(const RtpPacket::Ptr &rtp, bool key_pos) { + auto payload_size = rtp->getPayloadSize(); + if (payload_size < kAggregationHeaderSize) { + return false; + } + + uint32_t ssrc = rtp->getSSRC(); + if (!_has_last_ssrc || _last_ssrc != ssrc) { + resetState(); + _last_ssrc = ssrc; + _has_last_ssrc = true; + } + + auto stamp = rtp->getStampMS(); + auto payload = rtp->getPayload(); + auto seq = rtp->getSeq(); + + // 解析聚合头 + auto agg_header = parseAggregationHeader(payload[0]); + + const uint8_t* data = payload + kAggregationHeaderSize; + size_t remaining = payload_size - kAggregationHeaderSize; + + // InfoL << "RTP seq=" << seq << ", Z=" << agg_header.first_obu_is_fragment + // << ", Y=" << agg_header.last_obu_is_fragment + // << ", W=" << agg_header.num_obu_elements + // << ", N=" << agg_header.starts_new_coded_video_sequence + // << ", payload_size=" << remaining; + + // if (remaining > 0) { + // std::ostringstream hex_stream; + // for (size_t i = 0; i < std::min(remaining, size_t(16)); ++i) { + // hex_stream << std::hex << std::setw(2) << std::setfill('0') << (int)data[i] << " "; + // } + // InfoL << "RTP payload hex: " << hex_stream.str(); + // } + + // 如果开始新的编码视频序列,清理之前的状态 + if (agg_header.starts_new_coded_video_sequence) { + InfoL << "Starting new coded video sequence"; + resetState(); + obtainFrame(); + } + + if (_has_last_seq) { + uint16_t expected = _last_seq + 1; + if (seq != expected && _assembling_fragment) { + WarnL << "RTP seq gap while assembling fragment, expected=" << expected + << " got=" << seq << ", dropping incomplete OBU"; + _fragment_buffer.clear(); + _assembling_fragment = false; + } + } + _last_seq = seq; + _has_last_seq = true; + + if (!processPayload(agg_header, data, remaining)) { + resetState(); + obtainFrame(); + return false; + } + + bool marker = rtp->getHeader()->mark; + if (marker) { + if (_assembling_fragment) { + WarnL << "Marker bit set while awaiting fragment continuation"; + _fragment_buffer.clear(); + _assembling_fragment = false; + } + _last_dts = stamp; + if (!_received_keyframe) { + WarnL << "AV1 RTP packet before keyframe, dropping"; + _frame->_buffer.clear(); + obtainFrame(); + return false; + } + flushFrame(stamp); + return true; + } + + _last_dts = stamp; + return false; +} + +bool AV1RtpDecoder::processPayload(const AggregationHeader& agg_header, + const uint8_t* data, + size_t remaining) { + size_t element_index = 0; + int expected_elements = agg_header.num_obu_elements; + + while (remaining > 0) { + uint64_t element_size = 0; + bool has_size = (expected_elements == 0) || (static_cast(element_index) < expected_elements - 1); + if (has_size) { + if (!readLeb128(data, remaining, element_size)) { + WarnL << "Failed to read OBU element size, trying fallback parsing"; + // 兼容性回退:如果leb128解析失败,尝试直接使用剩余字节数 + element_size = remaining; + } else if (element_size > remaining) { + WarnL << "OBU element size (" << element_size << ") exceeds remaining payload (" + << remaining << "), using remaining size"; + element_size = remaining; + } + } else { + element_size = remaining; + } + + std::vector element_bytes; + element_bytes.reserve(element_size); + if (element_size > 0) { + element_bytes.insert(element_bytes.end(), data, data + element_size); + data += element_size; + remaining -= element_size; + } + + bool is_first = element_index == 0; + bool is_last = (remaining == 0); + + if (is_first && agg_header.first_obu_is_fragment) { + if (_fragment_buffer.empty()) { + WarnL << "Unexpected fragment continuation in AV1 RTP packet"; + return false; + } + _fragment_buffer.insert(_fragment_buffer.end(), element_bytes.begin(), element_bytes.end()); + } else { + if (_assembling_fragment && !_fragment_buffer.empty()) { + WarnL << "Previous fragment never completed, discarding"; + return false; + } + _fragment_buffer = std::move(element_bytes); + } + + bool will_continue = is_last && agg_header.last_obu_is_fragment; + if (will_continue) { + _assembling_fragment = true; + } else { + if (!emitObu(_fragment_buffer.data(), _fragment_buffer.size())) { + return false; + } + _fragment_buffer.clear(); + _assembling_fragment = false; + } + + ++element_index; + } + + if (expected_elements > 0 && static_cast(element_index) != expected_elements) { + WarnL << "Mismatch between W field (" << expected_elements + << ") and parsed OBU elements (" << element_index + << "), tolerating for compatibility"; + // 不返回false,继续处理以提高兼容性 + } + + return true; +} + +bool AV1RtpDecoder::emitObu(const uint8_t* data, size_t size) { + if (size == 0) { + return true; + } + + if (size < 1) { + WarnL << "Empty OBU fragment"; + return false; + } + + uint8_t obu_header = data[0]; + size_t header_size = 1; + + // 检查OBU头部是否已经包含size bit + bool already_has_size = obuHasSize(obu_header); + + // 如果RTP包中的OBU已经包含size字段,需要特殊处理 + if (already_has_size) { + //WarnL << "RTP OBU contains size field"; + + // 跳过extension header处理 + if (obuHasExtension(obu_header)) { + if (size < 2) { + WarnL << "OBU with extension flag but insufficient data"; + return false; + } + header_size = 2; + } + + // 读取原始的size字段 + const uint8_t* ptr = data + header_size; + size_t remaining = size - header_size; + uint64_t original_size = 0; + + if (!readLeb128(ptr, remaining, original_size)) { + WarnL << "Failed to read original OBU size field"; + return false; + } + + if (original_size != remaining) { + WarnL << "OBU size mismatch in RTP packet, original_size=" << original_size + << " remaining=" << remaining; + } + + // 直接拷贝完整的OBU(包括已有的size字段) + _frame->_buffer.append((char*)data, size); + } else { + // 标准情况:RTP包中的OBU没有size字段,需要我们添加 + + // 写入带size bit的OBU头部 + _frame->_buffer.push_back(obu_header | kObuSizePresentBit); + + if (obuHasExtension(obu_header)) { + if (size < 2) { + WarnL << "OBU with extension flag but insufficient data"; + return false; + } + _frame->_buffer.push_back(data[1]); + header_size = 2; + } + + if (size < header_size) { + WarnL << "Invalid OBU size"; + return false; + } + + // 计算payload大小并写入leb128编码的size字段 + uint64_t payload_size = size - header_size; + uint8_t size_bytes[8]; + size_t size_len = writeLeb128(payload_size, size_bytes); + _frame->_buffer.append((char*)size_bytes, size_len); + + // 拷贝payload数据 + if (payload_size > 0) { + _frame->_buffer.append((char*)data + header_size, payload_size); + } + } + + if (obuType(obu_header) == kObuTypeSequenceHeader) { + _received_keyframe = true; + } + + return true; +} + +void AV1RtpDecoder::flushFrame(uint64_t stamp) { + if (_frame->_buffer.empty()) { + return; + } + _frame->_dts = stamp; + _frame->_pts = stamp; + RtpCodec::inputFrame(_frame); + obtainFrame(); +} + +void AV1RtpDecoder::resetState() { + _fragment_buffer.clear(); + _assembling_fragment = false; + _has_last_seq = false; + _received_keyframe = false; +} + +} // namespace mediakit diff --git a/ext-codec/AV1Rtp.h b/ext-codec/AV1Rtp.h new file mode 100644 index 00000000..ddd02141 --- /dev/null +++ b/ext-codec/AV1Rtp.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_AV1RTP_H +#define ZLMEDIAKIT_AV1RTP_H + +#include "Rtsp/RtpCodec.h" +#include "Extension/Frame.h" +#include "Extension/CommonRtp.h" + +namespace mediakit { + +/** + * AV1 RTP编码器 + */ +class AV1RtpEncoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + AV1RtpEncoder(); + ~AV1RtpEncoder() override = default; + + bool inputFrame(const Frame::Ptr &frame) override; + +private: + // AV1 OBU信息 + struct ObuInfo { + uint8_t header; + uint8_t extension_header; + const uint8_t* payload_data; + size_t payload_size; + bool has_extension; + bool has_size_field; + }; + + std::vector parseObus(const uint8_t* data, size_t size); + void outputRtp(const uint8_t* data, size_t len, bool mark, uint64_t stamp, uint8_t aggregation_header); + uint8_t makeAggregationHeader(bool first_obu_is_fragment, bool last_obu_is_fragment, + int num_obu_elements, bool starts_new_coded_video_sequence); + bool sendObu(const ObuInfo& obu, bool is_first_obu, bool is_last_obu, + bool starts_new_sequence, uint64_t stamp, size_t max_payload_size); + +private: + bool _got_key_frame = false; +}; + +/** + * AV1 RTP解码器 + */ +class AV1RtpDecoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + AV1RtpDecoder(); + ~AV1RtpDecoder() override = default; + + bool inputRtp(const RtpPacket::Ptr &rtp, bool key_pos = false) override; + +private: + struct AggregationHeader { + bool first_obu_is_fragment; // Z bit + bool last_obu_is_fragment; // Y bit + int num_obu_elements; // W field (0 = any number) + bool starts_new_coded_video_sequence; // N bit + }; + + AggregationHeader parseAggregationHeader(uint8_t header); + void obtainFrame(); + bool emitObu(const uint8_t* data, size_t size); + bool processPayload(const AggregationHeader& agg_header, const uint8_t* data, + size_t remaining); + void flushFrame(uint64_t stamp); + void resetState(); + +private: + uint64_t _last_dts = 0; + FrameImp::Ptr _frame; + std::vector _fragment_buffer; + bool _assembling_fragment = false; + bool _received_keyframe = false; + bool _has_last_seq = false; + uint16_t _last_seq = 0; + bool _has_last_ssrc = false; + uint32_t _last_ssrc = 0; +}; + +}//namespace mediakit +#endif //ZLMEDIAKIT_AV1RTP_H diff --git a/ext-codec/CMakeLists.txt b/ext-codec/CMakeLists.txt index 23190fb8..5fefad0c 100644 --- a/ext-codec/CMakeLists.txt +++ b/ext-codec/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/ext-codec/G711.cpp b/ext-codec/G711.cpp index a456f8bc..c4bf255f 100644 --- a/ext-codec/G711.cpp +++ b/ext-codec/G711.cpp @@ -13,18 +13,35 @@ #include "Extension/Factory.h" #include "Extension/CommonRtp.h" #include "Extension/CommonRtmp.h" - +#include "riff-acm.h" using namespace std; using namespace toolkit; namespace mediakit { -Track::Ptr G711Track::clone() const { - return std::make_shared(*this); +Buffer::Ptr G711Track::getExtraData() const { + struct wave_format_t wav {}; + wav.wFormatTag = getCodecId() == CodecG711A ? WAVE_FORMAT_ALAW : WAVE_FORMAT_MULAW; + wav.nChannels = getAudioChannel(); + wav.nSamplesPerSec = getAudioSampleRate(); + wav.nAvgBytesPerSec = 8000; + wav.nBlockAlign = 1; + wav.wBitsPerSample = 8; + auto buff = BufferRaw::create(18 + wav.cbSize); + wave_format_save(&wav, (uint8_t*)buff->data(), buff->size()); + return buff; } -Sdp::Ptr G711Track::getSdp(uint8_t payload_type) const { - return std::make_shared(payload_type, *this); +void G711Track::setExtraData(const uint8_t *data, size_t size) { + struct wave_format_t wav; + if (wave_format_load(data, size, &wav) > 0) { + // Successfully parsed Opus header + _sample_rate = wav.nSamplesPerSec; + _channels = wav.nChannels; + _codecid = (wav.wFormatTag == WAVE_FORMAT_ALAW) ? CodecG711A : CodecG711U; + } else { + WarnL << "Failed to parse G711 extra data"; + } } namespace { diff --git a/ext-codec/G711.h b/ext-codec/G711.h index 3f1e9f64..aadad0d8 100644 --- a/ext-codec/G711.h +++ b/ext-codec/G711.h @@ -18,19 +18,16 @@ namespace mediakit{ /** * G711音频通道 - * G711 audio channel - - - * [AUTO-TRANSLATED:57f8bc08] */ class G711Track : public AudioTrackImp{ public: using Ptr = std::shared_ptr; G711Track(CodecId codecId, int sample_rate = 8000, int channels = 1, int sample_bit = 16) : AudioTrackImp(codecId, sample_rate, channels, sample_bit) {} + toolkit::Buffer::Ptr getExtraData() const override; + void setExtraData(const uint8_t *data, size_t size) override; private: - Sdp::Ptr getSdp(uint8_t payload_type) const override; - Track::Ptr clone() const override; + Track::Ptr clone() const override { return std::make_shared(*this); } }; }//namespace mediakit diff --git a/ext-codec/G711Rtp.cpp b/ext-codec/G711Rtp.cpp index d60e7362..198a1103 100644 --- a/ext-codec/G711Rtp.cpp +++ b/ext-codec/G711Rtp.cpp @@ -38,7 +38,8 @@ bool G711RtpEncoder::inputFrame(const Frame::Ptr &frame) { _buffer.append(ptr, size); while (_buffer.size() >= _pkt_bytes) { - RtpCodec::inputRtp(getRtpInfo().makeRtp(TrackAudio, _buffer.data(), _pkt_bytes, false, in_pts), false); + auto tmp = (in_pts+_pkt_dur_ms-1)/_pkt_dur_ms*_pkt_dur_ms; + RtpCodec::inputRtp(getRtpInfo().makeRtp(TrackAudio, _buffer.data(), _pkt_bytes, false, tmp), false); in_pts += _pkt_dur_ms; _buffer.erase(0, _pkt_bytes); } diff --git a/ext-codec/H264.cpp b/ext-codec/H264.cpp index 38c1d9bc..28a6d9d0 100644 --- a/ext-codec/H264.cpp +++ b/ext-codec/H264.cpp @@ -153,7 +153,6 @@ bool H264Track::ready() const { bool H264Track::inputFrame(const Frame::Ptr &frame) { using H264FrameInternal = FrameInternal; int type = H264_TYPE(frame->data()[frame->prefixSize()]); - if ((type == H264Frame::NAL_B_P || type == H264Frame::NAL_IDR) && ready()) { return inputFrame_l(frame); } @@ -263,6 +262,10 @@ Track::Ptr H264Track::clone() const { bool H264Track::inputFrame_l(const Frame::Ptr &frame) { int type = H264_TYPE(frame->data()[frame->prefixSize()]); + if (type == H264Frame::NAL_AUD) { + // AUD帧丢弃 + return false; + } bool ret = true; switch (type) { case H264Frame::NAL_SPS: { @@ -388,7 +391,7 @@ Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { // If there is no sps/pps in the sdp, then it may be possible to recover the sps/pps in the subsequent rtp return std::make_shared(); } - return std::make_shared(sps, pps, 0, 0); + return std::make_shared(sps, pps, prefixSize(sps.data(), sps.size()), prefixSize(pps.data(), pps.size())); } RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { diff --git a/ext-codec/H265.cpp b/ext-codec/H265.cpp index 0d0299b3..906973a0 100644 --- a/ext-codec/H265.cpp +++ b/ext-codec/H265.cpp @@ -160,6 +160,7 @@ toolkit::Buffer::Ptr H265Track::getExtraData() const { WarnL << "生成H265 extra_data 失败"; return nullptr; } + extra_data.resize(extra_data_size); return std::make_shared(std::move(extra_data)); #else WarnL << "请开启MP4相关功能并使能\"ENABLE_MP4\",否则对H265的支持不完善"; @@ -215,6 +216,108 @@ void H265Track::insertConfigFrame(const Frame::Ptr &frame) { } } +class BitReader { +public: + BitReader(const uint8_t* data, size_t size) : _data(data), _size(size), _bitPos(0) {} + + uint32_t readBits(int n) { + uint32_t result = 0; + for (int i = 0; i < n; i++) { + if (_bitPos >= _size * 8) throw std::runtime_error("Out of range"); + int bytePos = _bitPos / 8; + int bitOffset = 7 - (_bitPos % 8); + result = (result << 1) | ((_data[bytePos] >> bitOffset) & 0x01); + _bitPos++; + } + return result; + } + + void skipBits(int n) { + _bitPos += n; + if (_bitPos > _size * 8) throw std::runtime_error("Skip out of range"); + } + +private: + const uint8_t* _data; + size_t _size; + size_t _bitPos; +}; + +struct HevcProfileInfo { + int profile_id = -1; // profile-id + int level_id = -1; // level-id + int tier_flag = -1; // tier-flag +}; + +// 移除 00 00 03 防竞争字节 +std::vector removeEmulationPrevention(const uint8_t *data, size_t size) { + std::vector out; + out.reserve(size); + for (size_t i = 0; i < size; i++) { + if (i + 2 < size && data[i] == 0x00 && data[i + 1] == 0x00 && data[i + 2] == 0x03) { + out.push_back(0x00); + out.push_back(0x00); + i += 2; // skip 0x00 0x00 0x03 + } else { + out.push_back(data[i]); + } + } + return out; +} + +// 从 VPS 或 SPS 里提取 profile/level/tier 信息 +HevcProfileInfo parse_hevc_profile_tier_level(const uint8_t *nalu, size_t size) { + // 去掉起始码 (00 00 01 或 00 00 00 01) + size_t offset = 0; + if (size > 4 && nalu[0] == 0x00 && nalu[1] == 0x00) { + if (nalu[2] == 0x01) + offset = 3; + else if (nalu[2] == 0x00 && nalu[3] == 0x01) + offset = 4; + } + + auto rbsp = removeEmulationPrevention(nalu + offset, size - offset); + BitReader br(rbsp.data(), rbsp.size()); + + // ---- NALU header ---- + br.skipBits(1 + 6 + 6 + 3); // forbidden_zero_bit + nal_unit_type + nuh_layer_id + nuh_temporal_id_plus1 + + // VPS 和 SPS 都包含 profile_tier_level() + // 先解析最少需要的部分 + + // vps_video_parameter_set_id 或 sps_video_parameter_set_id (略过) + br.readBits(4); + + // sps 里还有 sps_max_sub_layers_minus1 + uint32_t max_sub_layers_minus1 = br.readBits(3); + // temporal_id_nesting_flag + br.readBits(1); + + // ---- profile_tier_level ---- + HevcProfileInfo info; + uint32_t profile_space = br.readBits(2); // general_profile_space + info.tier_flag = br.readBits(1); // general_tier_flag + info.profile_id = br.readBits(5); // general_profile_idc + + // general_profile_compatibility_flag[32] + for (int i = 0; i < 32; i++) + br.readBits(1); + + // general_progressive_source_flag 等 (跳过) + br.readBits(1); // progressive_source_flag + br.readBits(1); // interlaced_source_flag + br.readBits(1); // non_packed_constraint_flag + br.readBits(1); // frame_only_constraint_flag + + // general_reserved_zero_44bits + br.skipBits(44); + + // general_level_idc (8 bits) + info.level_id = br.readBits(8); + + return info; +} + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// /** @@ -247,7 +350,9 @@ public: _printer << "b=AS:" << bitrate << "\r\n"; } _printer << "a=rtpmap:" << payload_type << " " << getCodecName(CodecH265) << "/" << 90000 << "\r\n"; - _printer << "a=fmtp:" << payload_type << " "; + + auto info = parse_hevc_profile_tier_level((uint8_t *)strSPS.data(), strSPS.size()); + _printer << "a=fmtp:" << payload_type << " level-id=" << info.level_id << "; profile-id=" << info.profile_id << "; tier-flag=" << info.tier_flag << "; "; _printer << "sprop-vps="; _printer << encodeBase64(strVPS) << "; "; _printer << "sprop-sps="; @@ -287,7 +392,10 @@ Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { // If there is no sps/pps in the sdp, then it may be possible to recover sps/pps from the subsequent rtp return std::make_shared(); } - return std::make_shared(vps, sps, pps, 0, 0, 0); + return std::make_shared(vps, sps, pps, + prefixSize(vps.data(), vps.size()), + prefixSize(sps.data(), sps.size()), + prefixSize(pps.data(), pps.size())); } RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { diff --git a/ext-codec/H265Rtp.cpp b/ext-codec/H265Rtp.cpp index 8b3a3501..24e90d67 100644 --- a/ext-codec/H265Rtp.cpp +++ b/ext-codec/H265Rtp.cpp @@ -268,12 +268,12 @@ void H265RtpEncoder::packRtpFu(const char *ptr, size_t len, uint64_t pts, bool i auto nal_type = H265_TYPE(ptr[0]); //获取NALU的5bit 帧类型 unsigned char s_e_flags; bool fu_start = true; - bool mark_bit = false; + bool fu_end = false; size_t offset = 2; - while (!mark_bit) { + while (!fu_end) { if (len <= offset + max_size) { // FU end - mark_bit = true; + fu_end = true; max_size = len - offset; s_e_flags = (1 << 6) | nal_type; } else if (fu_start) { @@ -287,7 +287,9 @@ void H265RtpEncoder::packRtpFu(const char *ptr, size_t len, uint64_t pts, bool i { // 传入nullptr先不做payload的内存拷贝 [AUTO-TRANSLATED:7ed49f0a] // Pass in nullptr first, do not copy the payload memory - auto rtp = getRtpInfo().makeRtp(TrackVideo, nullptr, max_size + 3, mark_bit, pts); + // 只有FU的最后一个分片且整个帧需要设置mark时才设置mark位 + bool mark_bit = fu_end && is_mark; + auto rtp = getRtpInfo().makeRtp(TrackVideo, nullptr, max_size + 3, mark_bit && is_mark, pts); //yzw 帧(不是NALU,多TILE时一帧有多个NALU)最后一个rtp才设置mark位 // rtp payload 负载部分 [AUTO-TRANSLATED:03a5ef9b] // rtp payload load part uint8_t *payload = rtp->getPayload(); diff --git a/ext-codec/JPEGRtp.cpp b/ext-codec/JPEGRtp.cpp index 648d2d02..61fe0ec3 100644 --- a/ext-codec/JPEGRtp.cpp +++ b/ext-codec/JPEGRtp.cpp @@ -133,7 +133,7 @@ static inline void bytestream2_put_be16(PutByteContext *p, uint16_t value) { } } -static inline void bytestream2_put_be24(PutByteContext *p, uint16_t value) { +static inline void bytestream2_put_be24(PutByteContext *p, uint32_t value) { if (!p->eof && (p->buffer_end - p->buffer >= 2)) { p->buffer[0] = value >> 16; p->buffer[1] = value >> 8; diff --git a/ext-codec/MP2A.cpp b/ext-codec/MP2A.cpp new file mode 100644 index 00000000..31029842 --- /dev/null +++ b/ext-codec/MP2A.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "MP2A.h" +#include "MP2ARtp.h" +#include "Extension/Factory.h" +#include "Extension/CommonRtmp.h" +#include "Rtsp/Rtsp.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +// ======================== MpegAudioFrameInfo ======================== + +// MPEG Audio 版本表 +// MPEG Audio version table +// Index: version_bits (2 bits from header) +// 00 = MPEG 2.5, 01 = reserved, 10 = MPEG 2, 11 = MPEG 1 +static const int s_mpeg_version[] = { 3, 0, 2, 1 }; // 3=MPEG2.5, 0=reserved, 2=MPEG2, 1=MPEG1 + +// Layer 表: 00=reserved, 01=III, 10=II, 11=I +static const int s_mpeg_layer[] = { 0, 3, 2, 1 }; + +// MPEG-1 比特率表 (kbps) +// bitrate_index: 0-15, layer: 1-3 +static const int s_bitrate_mpeg1[][16] = { + // Layer I + { 0, 32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448, 0 }, + // Layer II + { 0, 32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384, 0 }, + // Layer III + { 0, 32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 0 }, +}; + +// MPEG-2/2.5 比特率表 (kbps) +static const int s_bitrate_mpeg2[][16] = { + // Layer I + { 0, 32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256, 0 }, + // Layer II / III + { 0, 8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160, 0 }, +}; + +// 采样率表 (Hz) +// Index: [version_index][samplerate_index] +static const int s_sample_rate[][4] = { + { 44100, 48000, 32000, 0 }, // MPEG-1 + { 22050, 24000, 16000, 0 }, // MPEG-2 + { 11025, 12000, 8000, 0 }, // MPEG-2.5 +}; + +bool MpegAudioFrameInfo::parse(const uint8_t *data, size_t size, MpegAudioFrameInfo &info) { + if (size < 4) { + return false; + } + // 检查同步字 0xFFE0 (11 bits all 1) + if (data[0] != 0xFF || (data[1] & 0xE0) != 0xE0) { + return false; + } + + int version_bits = (data[1] >> 3) & 0x03; + int layer_bits = (data[1] >> 1) & 0x03; + // int protection = !(data[1] & 0x01); + int bitrate_index = (data[2] >> 4) & 0x0F; + int samplerate_index = (data[2] >> 2) & 0x03; + int padding = (data[2] >> 1) & 0x01; + int channel_mode = (data[3] >> 6) & 0x03; + + int ver = s_mpeg_version[version_bits]; + int layer = s_mpeg_layer[layer_bits]; + + if (ver == 0 || layer == 0 || samplerate_index == 3 || bitrate_index == 0 || bitrate_index == 15) { + return false; + } + + int ver_index = ver - 1; // 0=MPEG1, 1=MPEG2, 2=MPEG2.5 + int sr = s_sample_rate[ver_index][samplerate_index]; + if (sr == 0) { + return false; + } + + int bitrate = 0; + if (ver == 1) { + // MPEG-1 + bitrate = s_bitrate_mpeg1[layer - 1][bitrate_index]; + } else { + // MPEG-2 / MPEG-2.5 + if (layer == 1) { + bitrate = s_bitrate_mpeg2[0][bitrate_index]; + } else { + bitrate = s_bitrate_mpeg2[1][bitrate_index]; + } + } + + info.version = ver; + info.layer = layer; + info.bitrate = bitrate; + info.sample_rate = sr; + info.channels = (channel_mode == 3) ? 1 : 2; // 3=mono, 其他=stereo + + // 计算每帧的采样数和帧大小 + if (layer == 1) { + // Layer I: 384 samples + info.samples_per_frame = 384; + info.frame_size = (12 * bitrate * 1000 / sr + padding) * 4; + } else if (layer == 2) { + // Layer II: 1152 samples + info.samples_per_frame = 1152; + info.frame_size = 144 * bitrate * 1000 / sr + padding; + } else { + // Layer III + if (ver == 1) { + info.samples_per_frame = 1152; + info.frame_size = 144 * bitrate * 1000 / sr + padding; + } else { + info.samples_per_frame = 576; + info.frame_size = 72 * bitrate * 1000 / sr + padding; + } + } + return true; +} + +// ======================== MP2ATrack ======================== + +bool MP2ATrack::inputFrame(const Frame::Ptr &frame) { + if (!_info_parsed) { + auto data = (const uint8_t *)frame->data() + frame->prefixSize(); + auto size = frame->size() - frame->prefixSize(); + MpegAudioFrameInfo info; + if (MpegAudioFrameInfo::parse(data, size, info)) { + _sample_rate = info.sample_rate; + _channels = info.channels; + _info_parsed = true; + } + } + return AudioTrackImp::inputFrame(frame); +} + +Sdp::Ptr MP2ATrack::getSdp(uint8_t pt) const { + // RFC 2250/3551: MPA 的 RTP 时钟频率固定为 90000,而不是音频采样率 + // RFC 2250/3551: MPA RTP clock rate is fixed at 90000, not the audio sample rate + class MP2ASdp : public Sdp { + public: + // 注意:Sdp 基类构造必须传入 90000 作为 sample_rate + MP2ASdp(uint8_t payload_type, int channels, int bitrate) + : Sdp(90000, payload_type) { + _printer << "m=audio 0 RTP/AVP " << (int)payload_type << "\r\n"; + if (bitrate) { + _printer << "b=AS:" << bitrate << "\r\n"; + } + _printer << "a=rtpmap:" << (int)payload_type << " MPA/90000/" << channels << "\r\n"; + } + std::string getSdp() const override { return _printer; } + + private: + toolkit::_StrPrinter _printer; + }; + return std::make_shared(pt, getAudioChannel(), getBitRate() >> 10); +} + +Track::Ptr MP2ATrack::clone() const { + return std::make_shared(*this); +} + +namespace { + +CodecId getCodec() { + return CodecMP2A; +} + +Track::Ptr getTrackByCodecId(int sample_rate, int channels, int sample_bit) { + return std::make_shared(sample_rate, channels); +} + +Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { + return std::make_shared(track->_samplerate, track->_channel); +} + +RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpDecoderByCodecId() { + return std::make_shared(); +} + +RtmpCodec::Ptr getRtmpEncoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +RtmpCodec::Ptr getRtmpDecoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +Frame::Ptr getFrameFromPtr(const char *data, size_t bytes, uint64_t dts, uint64_t pts) { + return std::make_shared((char *)data, bytes, dts, pts); +} + +} // namespace + +CodecPlugin mp2a_plugin = { getCodec, + getTrackByCodecId, + getTrackBySdp, + getRtpEncoderByCodecId, + getRtpDecoderByCodecId, + getRtmpEncoderByTrack, + getRtmpDecoderByTrack, + getFrameFromPtr }; + +} // namespace mediakit diff --git a/ext-codec/MP2A.h b/ext-codec/MP2A.h new file mode 100644 index 00000000..a3c6c1ec --- /dev/null +++ b/ext-codec/MP2A.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_MP2A_H +#define ZLMEDIAKIT_MP2A_H + +#include "Extension/Frame.h" +#include "Extension/Track.h" + +namespace mediakit { + +/** + * MPEG-1/2 Audio (Layer I/II) 帧辅助类模板 + * MPEG-1/2 Audio (Layer I/II) frame helper class template + */ +template +class MP2AFrameHelper : public Parent { +public: + using Ptr = std::shared_ptr; + + template + MP2AFrameHelper(ARGS &&...args) + : Parent(std::forward(args)...) { + this->_codec_id = CodecMP2A; + } + + bool keyFrame() const override { return false; } + bool configFrame() const override { return false; } +}; + +/// MPEG-1/2 Audio 帧类 +using MP2AFrame = MP2AFrameHelper; +using MP2AFrameNoCacheAble = MP2AFrameHelper; + +// MPEG Audio 帧头解析工具 +// MPEG Audio frame header parsing utility +struct MpegAudioFrameInfo { + int version = 0; // 1: MPEG-1, 2: MPEG-2, 3: MPEG-2.5 + int layer = 0; // 1: Layer I, 2: Layer II, 3: Layer III + int bitrate = 0; // kbps + int sample_rate = 0; // Hz + int channels = 0; // 1: mono, 2: stereo + int frame_size = 0; // bytes per frame + int samples_per_frame = 0; + + /** + * 从 MPEG Audio sync word 解析帧头信息 + * Parse frame header info from MPEG Audio sync word + * @param data 数据指针,至少4字节 + * @param size 数据大小 + * @return 是否解析成功 + */ + static bool parse(const uint8_t *data, size_t size, MpegAudioFrameInfo &info); +}; + +/** + * MPEG-1/2 Audio (Layer I/II) Track + * 对应 CodecMP2A + */ +class MP2ATrack : public AudioTrackImp { +public: + using Ptr = std::shared_ptr; + + MP2ATrack(int sample_rate = 44100, int channels = 2) + : AudioTrackImp(CodecMP2A, sample_rate, channels, 16) {} + + bool inputFrame(const Frame::Ptr &frame) override; + +private: + /** + * RFC 2250/3551 规定 MPA 的 RTP 时钟频率固定为 90000 + * RFC 2250/3551 specifies MPA RTP clock rate is fixed at 90000 + */ + Sdp::Ptr getSdp(uint8_t payload_type) const override; + Track::Ptr clone() const override; + +private: + bool _info_parsed = false; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_MP2A_H diff --git a/ext-codec/MP2ARtp.cpp b/ext-codec/MP2ARtp.cpp new file mode 100644 index 00000000..63a56874 --- /dev/null +++ b/ext-codec/MP2ARtp.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "MP2ARtp.h" + +namespace mediakit { + +// ======================== MP2ARtpEncoder ======================== + +void MP2ARtpEncoder::outputRtp(const char *data, size_t len, size_t frag_offset, bool mark, uint64_t stamp) { + // RFC 2250 Section 3.5: + // 4 bytes MPEG Audio-specific header + ES data + auto rtp = getRtpInfo().makeRtp(TrackAudio, nullptr, len + kMP2AHeaderSize, mark, stamp); + auto payload = rtp->getPayload(); + + // MPEG Audio-specific header + // MBZ (16 bits) = 0 + payload[0] = 0; + payload[1] = 0; + // Frag_offset (16 bits) + payload[2] = (frag_offset >> 8) & 0xFF; + payload[3] = frag_offset & 0xFF; + + // ES data + memcpy(payload + kMP2AHeaderSize, data, len); + + RtpCodec::inputRtp(std::move(rtp), false); +} + +bool MP2ARtpEncoder::inputFrame(const Frame::Ptr &frame) { + auto data = (const uint8_t *)frame->data() + frame->prefixSize(); + auto total_size = (size_t)(frame->size() - frame->prefixSize()); + if (total_size <= 0) { + return false; + } + + auto max_payload = getRtpInfo().getMaxSize() - kMP2AHeaderSize; + auto base_dts = frame->dts(); + + // TS demux 可能一次回调多个完整的 MPEG Audio 帧(一个 PES 包), + // 需要逐帧解析并独立打 RTP 包,否则 FFmpeg 等接收端会因为分片 + // 导致 RTP payload 不以 sync word 开头而报 "Header missing"。 + size_t pos = 0; + int frame_index = 0; + + while (pos + 4 <= total_size) { + // 检查 MPEG Audio sync word + if (data[pos] != 0xFF || (data[pos + 1] & 0xE0) != 0xE0) { + // 跳过无效字节,寻找下一个 sync word + ++pos; + continue; + } + + // 解析帧头获取帧大小 + MpegAudioFrameInfo info; + if (!MpegAudioFrameInfo::parse(data + pos, total_size - pos, info) || info.frame_size <= 0) { + ++pos; + continue; + } + + size_t frame_size = (size_t)info.frame_size; + if (pos + frame_size > total_size) { + // 不完整的帧,打包剩余数据 + frame_size = total_size - pos; + } + + // 计算当前帧的时间戳偏移(毫秒) + // 每帧 samples_per_frame 个采样点,采样率 info.sample_rate + uint64_t stamp = base_dts; + if (frame_index > 0 && info.sample_rate > 0) { + stamp += (uint64_t)frame_index * info.samples_per_frame * 1000 / info.sample_rate; + } + + // 对单个 MPEG Audio 帧打 RTP 包 + auto ptr = (const char *)(data + pos); + size_t remain = frame_size; + size_t frag_offset = 0; + + while (remain > 0) { + if (remain <= max_payload) { + outputRtp(ptr, remain, frag_offset, true, stamp); + break; + } + outputRtp(ptr, max_payload, frag_offset, false, stamp); + ptr += max_payload; + remain -= max_payload; + frag_offset += max_payload; + } + + pos += frame_size; + ++frame_index; + } + + return true; +} + +// ======================== MP2ARtpDecoder ======================== + +MP2ARtpDecoder::MP2ARtpDecoder() { + obtainFrame(); +} + +void MP2ARtpDecoder::obtainFrame() { + _frame = FrameImp::create(); +} + +void MP2ARtpDecoder::flushData() { + if (_frame->_buffer.empty()) { + return; + } + RtpCodec::inputFrame(_frame); + obtainFrame(); +} + +bool MP2ARtpDecoder::inputRtp(const RtpPacket::Ptr &rtp, bool key_pos) { + auto payload_size = rtp->getPayloadSize(); + if (payload_size <= (ssize_t)kMP2AHeaderSize) { + // 负载太小,没有有效 ES 数据 + return false; + } + + auto payload = rtp->getPayload(); + auto stamp = rtp->getStamp(); + auto seq = rtp->getSeq(); + + // 解析 MPEG Audio-specific header (RFC 2250 Section 3.5) + // MBZ (16 bits) + Frag_offset (16 bits) + uint16_t frag_offset = (payload[2] << 8) | payload[3]; + + auto es_data = payload + kMP2AHeaderSize; + auto es_size = payload_size - kMP2AHeaderSize; + + if (frag_offset == 0) { + // frag_offset == 0 表示这是一个新帧(或完整帧)的开始 + // 先输出之前缓存的帧(如果有) + flushData(); + // 使用 90kHz 时间戳转换为毫秒 + _frame->_dts = rtp->getStampMS(); + _frame->_pts = _frame->_dts; + } else if (_frame->_buffer.empty()) { + // frag_offset != 0 但 buffer 为空,说明丢了第一个分片包,丢弃 + _last_seq = seq; + _last_stamp = stamp; + return false; + } else if (seq != (uint16_t)(_last_seq + 1)) { + // 分片包 seq 不连续,丢包了,丢弃当前帧 + WarnL << "mp2a rtp packet loss:" << _last_seq << " -> " << seq; + _frame->_buffer.clear(); + _last_seq = seq; + _last_stamp = stamp; + return false; + } + + _last_seq = seq; + _last_stamp = stamp; + + // 追加 ES 数据 + _frame->_buffer.append((char *)es_data, es_size); + + // mark bit 表示帧的最后一个 RTP 包,立即输出 + if (rtp->getHeader()->mark) { + flushData(); + } + + return false; +} + +} // namespace mediakit diff --git a/ext-codec/MP2ARtp.h b/ext-codec/MP2ARtp.h new file mode 100644 index 00000000..71c04324 --- /dev/null +++ b/ext-codec/MP2ARtp.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_MP2ARTP_H +#define ZLMEDIAKIT_MP2ARTP_H + +#include "MP2A.h" +#include "Rtsp/RtpCodec.h" + +namespace mediakit { + +// RFC 2250 Section 3.5 MPEG Audio-specific header (4 bytes) +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | MBZ | Frag_offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// MBZ: Must Be Zero (16 bits) +// Frag_offset: Byte offset into the audio frame for the data in this packet (16 bits) + +static constexpr size_t kMP2AHeaderSize = 4; + +/** + * MP2A (MPEG-1/2 Audio Layer I/II) RTP 编码器 + * RFC 2250 Section 3.5 + */ +class MP2ARtpEncoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + /** + * 输入 MPEG Audio 帧并打包为 RTP + * @param frame 帧数据 + */ + bool inputFrame(const Frame::Ptr &frame) override; + +private: + /** + * 输出一个 RTP 包 + * @param data ES 数据 + * @param len 数据长度 + * @param frag_offset 分片在帧内的偏移 + * @param mark 是否为帧最后一个包 + * @param stamp 时间戳(ms) + */ + void outputRtp(const char *data, size_t len, size_t frag_offset, bool mark, uint64_t stamp); +}; + +/** + * MP2A (MPEG-1/2 Audio Layer I/II) RTP 解码器 + * RFC 2250 Section 3.5 + */ +class MP2ARtpDecoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + MP2ARtpDecoder(); + + /** + * 输入 MPEG Audio RTP 包并解码 + * @param rtp rtp 数据包 + * @param key_pos 音频帧忽略此参数 + */ + bool inputRtp(const RtpPacket::Ptr &rtp, bool key_pos = false) override; + +private: + void obtainFrame(); + void flushData(); + +private: + uint16_t _last_seq = 0; + uint32_t _last_stamp = 0; + FrameImp::Ptr _frame; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_MP2ARTP_H diff --git a/ext-codec/MP2V.cpp b/ext-codec/MP2V.cpp new file mode 100644 index 00000000..6263a7a3 --- /dev/null +++ b/ext-codec/MP2V.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "MP2V.h" +#include "MP2VRtp.h" +#include "Extension/Factory.h" +#include "Rtsp/Rtsp.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +// MPEG-2 sequence header 帧率表 (ISO 13818-2 Table 6-4) +// MPEG-2 sequence header frame rate table +static const float s_mp2v_frame_rate_table[] = { + 0, // 0000 forbidden + 24000.0 / 1001, // 0001 23.976 + 24.0, // 0010 + 25.0, // 0011 + 30000.0 / 1001, // 0100 29.97 + 30.0, // 0101 + 50.0, // 0110 + 60000.0 / 1001, // 0111 59.94 + 60.0, // 1000 +}; + +void MP2VTrack::parseSequenceHeader(const uint8_t *data, size_t size) { + // 查找 sequence header start code: 00 00 01 B3 + // Look for sequence header start code: 00 00 01 B3 + for (size_t i = 0; i + 7 < size; ++i) { + if (data[i] == 0x00 && data[i + 1] == 0x00 && data[i + 2] == 0x01 && data[i + 3] == 0xB3) { + // sequence_header() 结构: + // horizontal_size_value: 12 bits + // vertical_size_value: 12 bits + // aspect_ratio_information: 4 bits + // frame_rate_code: 4 bits + _width = (data[i + 4] << 4) | ((data[i + 5] >> 4) & 0x0F); + _height = ((data[i + 5] & 0x0F) << 8) | data[i + 6]; + uint8_t frame_rate_code = data[i + 7] & 0x0F; + if (frame_rate_code > 0 && frame_rate_code <= 8) { + _fps = s_mp2v_frame_rate_table[frame_rate_code]; + } + _seq_header_parsed = true; + return; + } + } +} + +bool MP2VTrack::inputFrame(const Frame::Ptr &frame) { + if (!_seq_header_parsed) { + parseSequenceHeader((const uint8_t *)frame->data() + frame->prefixSize(), + frame->size() - frame->prefixSize()); + } + return VideoTrackImp::inputFrame(frame); +} + +Sdp::Ptr MP2VTrack::getSdp(uint8_t pt) const { + return std::make_shared(pt, *this); +} + +namespace { + +CodecId getCodec() { + return CodecMP2V; +} + +Track::Ptr getTrackByCodecId(int sample_rate, int channels, int sample_bit) { + return std::make_shared(); +} + +Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpDecoderByCodecId() { + return std::make_shared(); +} + +RtmpCodec::Ptr getRtmpEncoderByTrack(const Track::Ptr &track) { + WarnL << "Unsupported MP2V rtmp encoder"; + return nullptr; +} + +RtmpCodec::Ptr getRtmpDecoderByTrack(const Track::Ptr &track) { + WarnL << "Unsupported MP2V rtmp decoder"; + return nullptr; +} + +Frame::Ptr getFrameFromPtr(const char *data, size_t bytes, uint64_t dts, uint64_t pts) { + return std::make_shared((char *)data, bytes, dts, pts, 0); +} + +} // namespace + +CodecPlugin mp2v_plugin = { getCodec, + getTrackByCodecId, + getTrackBySdp, + getRtpEncoderByCodecId, + getRtpDecoderByCodecId, + getRtmpEncoderByTrack, + getRtmpDecoderByTrack, + getFrameFromPtr }; + +} // namespace mediakit diff --git a/ext-codec/MP2V.h b/ext-codec/MP2V.h new file mode 100644 index 00000000..3c8ce287 --- /dev/null +++ b/ext-codec/MP2V.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_MP2V_H +#define ZLMEDIAKIT_MP2V_H + +#include "Extension/Frame.h" +#include "Extension/Track.h" + +namespace mediakit { + +/** + * MPEG-2 Video 帧辅助类模板 + * MPEG-2 Video frame helper class template + */ +template +class MP2VFrameHelper : public Parent { +public: + using Ptr = std::shared_ptr; + + template + MP2VFrameHelper(ARGS &&...args) + : Parent(std::forward(args)...) { + this->_codec_id = CodecMP2V; + } + + /** + * MPEG-2 视频起始码: 00 00 01 00 (picture_start_code) + * I帧判断:picture_coding_type == 1 (I-Picture) + * picture_coding_type 位于 picture header 的第 11-12 bit (从 temporal_reference 之后) + * + * MPEG-2 video start code: 00 00 01 00 (picture_start_code) + * I-frame detection: picture_coding_type == 1 (I-Picture) + */ + bool keyFrame() const override { + auto data = (const uint8_t *)this->data() + this->prefixSize(); + auto size = this->size() - this->prefixSize(); + return isMP2VKeyFrame(data, size); + } + + bool configFrame() const override { return false; } + + static bool isMP2VKeyFrame(const uint8_t *data, size_t size) { + // 查找 picture start code (00 00 01 00),然后检查 picture_coding_type + // Look for picture start code (00 00 01 00), then check picture_coding_type + for (size_t i = 0; i + 5 < size; ++i) { + if (data[i] == 0x00 && data[i + 1] == 0x00 && data[i + 2] == 0x01 && data[i + 3] == 0x00) { + // picture header: temporal_reference(10bits) + picture_coding_type(3bits) + // picture_coding_type: 001 = I, 010 = P, 011 = B + uint8_t picture_coding_type = (data[i + 5] >> 3) & 0x07; + return picture_coding_type == 1; + } + } + return false; + } +}; + +/// MPEG-2 Video 帧类 +using MP2VFrame = MP2VFrameHelper; +using MP2VFrameNoCacheAble = MP2VFrameHelper; + +/** + * MPEG-2 Video Track + */ +class MP2VTrack : public VideoTrackImp { +public: + using Ptr = std::shared_ptr; + + MP2VTrack() : VideoTrackImp(CodecMP2V) {} + + Track::Ptr clone() const override { return std::make_shared(*this); } + + bool inputFrame(const Frame::Ptr &frame) override; + +private: + Sdp::Ptr getSdp(uint8_t payload_type) const override; + + /** + * 从 sequence header 中解析宽高和帧率 + * Parse width, height and fps from sequence header + */ + void parseSequenceHeader(const uint8_t *data, size_t size); + +private: + bool _seq_header_parsed = false; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_MP2V_H diff --git a/ext-codec/MP2VRtp.cpp b/ext-codec/MP2VRtp.cpp new file mode 100644 index 00000000..afada09d --- /dev/null +++ b/ext-codec/MP2VRtp.cpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "MP2VRtp.h" +#include "Common/config.h" + +namespace mediakit { + +// ======================== MP2VRtpDecoder ======================== + +MP2VRtpDecoder::MP2VRtpDecoder() { + obtainFrame(); +} + +void MP2VRtpDecoder::obtainFrame() { + _frame = FrameImp::create(); +} + +bool MP2VRtpDecoder::inputRtp(const RtpPacket::Ptr &rtp, bool key_pos) { + auto seq = rtp->getSeq(); + auto last_gop_dropped = _gop_dropped; + bool is_gop_start = decodeRtp(rtp); + if (!_gop_dropped && seq != (uint16_t)(_last_seq + 1) && _last_seq) { + _gop_dropped = true; + WarnL << "start drop mp2v gop, last seq:" << _last_seq << ", rtp:\r\n" << rtp->dumpString(); + } + _last_seq = seq; + return is_gop_start && !last_gop_dropped; +} + +/** + * RFC 2250 MPEG Video-specific header (4 bytes): + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | MBZ |T| TR |AN|N|S|B|E| P | | BFC | | FFC | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * FBV FFV + * + * T: MPEG-2 specific header extension present (1 bit) + * TR: Temporal Reference (10 bits) + * AN: Active N bit (1 bit) + * N: New picture header (1 bit) + * S: Sequence-header-present (1 bit) + * B: Beginning-of-slice (1 bit) + * E: End-of-slice (1 bit) + * P: Picture-Type (3 bits): I(1), P(2), B(3), D(4) + * FBV: full_pel_backward_vector (1 bit) + * BFC: backward_f_code (3 bits) + * FFV: full_pel_forward_vector (1 bit) + * FFC: forward_f_code (3 bits) + */ +bool MP2VRtpDecoder::decodeRtp(const RtpPacket::Ptr &rtp) { + auto payload_size = rtp->getPayloadSize(); + if (payload_size <= (ssize_t)kMP2VHeaderSize) { + // 负载太小,不包含有效数据 + return false; + } + auto payload = rtp->getPayload(); + auto stamp = rtp->getStampMS(); + auto seq = rtp->getSeq(); + + // 解析 RFC 2250 MPEG Video-specific header + bool t_bit = (payload[0] >> 2) & 0x01; + // uint16_t temporal_ref = ((payload[0] & 0x03) << 8) | payload[1]; + // bool seq_header_present = (payload[2] >> 5) & 0x01; + // bool begin_of_slice = (payload[2] >> 4) & 0x01; + // bool end_of_slice = (payload[2] >> 3) & 0x01; + uint8_t picture_type = (payload[2] & 0x07); + + // 如果 T bit 置位,还有 4 字节的 MPEG-2 扩展头需要跳过 + size_t header_size = kMP2VHeaderSize + (t_bit ? 4 : 0); + if (payload_size <= (ssize_t)header_size) { + return false; + } + + auto es_data = payload + header_size; + auto es_size = payload_size - header_size; + + // 检查是否为新帧(时间戳变化) + if (!_frame->_buffer.empty() && stamp != _frame->_pts) { + // 时间戳变化,输出上一帧 + outputFrame(rtp); + } + + if (_frame->_buffer.empty()) { + // 新帧开始 + _frame->_pts = stamp; + _drop_flag = false; + _picture_type = picture_type; + } + + if (_drop_flag) { + return false; + } + + // 检测 seq 不连续,丢弃当前帧 + if (!_frame->_buffer.empty() && seq != (uint16_t)(_last_seq + 1) && _last_seq) { + _drop_flag = true; + _frame->_buffer.clear(); + return false; + } + + // 追加 ES 数据 + _frame->_buffer.append((char *)es_data, es_size); + + // RTP mark bit 标识帧结束 + if (rtp->getHeader()->mark) { + outputFrame(rtp); + return _picture_type == 1; // I-Picture + } + + return false; +} + +void MP2VRtpDecoder::outputFrame(const RtpPacket::Ptr &rtp) { + if (_frame->_buffer.empty()) { + return; + } + + // 生成 DTS(MPEG-2 有 B 帧,PTS 和 DTS 不一定相同) + _dts_generator.getDts(_frame->_pts, _frame->_dts); + + bool is_key = _frame->keyFrame(); + if (is_key && _gop_dropped) { + _gop_dropped = false; + InfoL << "new mp2v gop received, rtp:\r\n" << rtp->dumpString(); + } + if (!_gop_dropped) { + RtpCodec::inputFrame(_frame); + } + obtainFrame(); +} + +// ======================== MP2VRtpEncoder ======================== + +bool MP2VRtpEncoder::hasSequenceHeader(const uint8_t *data, size_t size) { + // 查找 sequence header start code: 00 00 01 B3 + for (size_t i = 0; i + 3 < size; ++i) { + if (data[i] == 0x00 && data[i + 1] == 0x00 && data[i + 2] == 0x01 && data[i + 3] == 0xB3) { + return true; + } + } + return false; +} + +void MP2VRtpEncoder::parsePictureInfo(const uint8_t *data, size_t size) { + _temporal_ref = 0; + _picture_type = 0; + _fbv = 0; + _bfc = 0; + _ffv = 0; + _ffc = 0; + _has_seq_header = hasSequenceHeader(data, size); + + // 查找 picture start code: 00 00 01 00 + for (size_t i = 0; i + 5 < size; ++i) { + if (data[i] == 0x00 && data[i + 1] == 0x00 && data[i + 2] == 0x01 && data[i + 3] == 0x00) { + // temporal_reference: 10 bits, picture_coding_type: 3 bits + _temporal_ref = (data[i + 4] << 2) | ((data[i + 5] >> 6) & 0x03); + _picture_type = (data[i + 5] >> 3) & 0x07; + + // 解析 motion vector codes (vbv_delay 之后) + // picture header: temporal_reference(10) + picture_coding_type(3) + vbv_delay(16) + if (i + 8 < size) { + uint8_t extra_byte = data[i + 8]; + if (_picture_type == 2 /* P */ || _picture_type == 3 /* B */) { + // full_pel_forward_vector(1) + forward_f_code(3) + _ffv = (extra_byte >> 2) & 0x01; + _ffc = ((extra_byte & 0x03) << 1); + if (i + 9 < size) { + _ffc |= (data[i + 9] >> 7) & 0x01; + } + } + if (_picture_type == 3 /* B */) { + // full_pel_backward_vector(1) + backward_f_code(3) 紧跟在 forward 之后 + if (i + 9 < size) { + _fbv = (data[i + 9] >> 6) & 0x01; + _bfc = (data[i + 9] >> 3) & 0x07; + } + } + } + return; + } + } +} + +void MP2VRtpEncoder::buildMpvHeader(uint8_t *buf, const uint8_t *data, size_t size, + bool is_begin_of_slice, bool is_end_of_slice) { + // RFC 2250 Section 3.4 + // Byte 0: MBZ(5) + T(1) + TR high 2 bits + // T = 0 (不发送 MPEG-2 扩展头) + buf[0] = (_temporal_ref >> 8) & 0x03; + + // Byte 1: TR low 8 bits + buf[1] = _temporal_ref & 0xFF; + + // Byte 2: AN(1) + N(1) + S(1) + B(1) + E(1) + P(3) + uint8_t byte2 = 0; + // AN = 0, N = 0 + if (_has_seq_header) { + byte2 |= 0x20; // S bit + } + if (is_begin_of_slice) { + byte2 |= 0x10; // B bit + } + if (is_end_of_slice) { + byte2 |= 0x08; // E bit + } + byte2 |= (_picture_type & 0x07); + buf[2] = byte2; + + // Byte 3: FBV(1) + BFC(3) + FFV(1) + FFC(3) + buf[3] = ((_fbv & 0x01) << 7) | ((_bfc & 0x07) << 4) | ((_ffv & 0x01) << 3) | (_ffc & 0x07); +} + +bool MP2VRtpEncoder::inputFrame(const Frame::Ptr &frame) { + auto ptr = (const uint8_t *)frame->data() + frame->prefixSize(); + auto size = frame->size() - frame->prefixSize(); + if (size == 0) { + return false; + } + + // 解析帧信息(picture type, temporal reference 等) + parsePictureInfo(ptr, size); + + bool is_key = frame->keyFrame(); + auto max_payload = getRtpInfo().getMaxSize() - kMP2VHeaderSize; + size_t offset = 0; + + while (offset < size) { + bool is_first = (offset == 0); + size_t payload_size; + bool is_last; + + if (size - offset <= max_payload) { + payload_size = size - offset; + is_last = true; + } else { + payload_size = max_payload; + is_last = false; + } + + // 构建 MPEG Video-specific header + uint8_t mpv_header[kMP2VHeaderSize]; + buildMpvHeader(mpv_header, ptr + offset, payload_size, is_first, is_last); + + // 创建 RTP 包:MPEG header + ES data + auto rtp = getRtpInfo().makeRtp(TrackVideo, nullptr, kMP2VHeaderSize + payload_size, is_last, frame->pts()); + auto rtp_payload = rtp->getPayload(); + + // 写入 MPEG Video-specific header + memcpy(rtp_payload, mpv_header, kMP2VHeaderSize); + // 写入 ES 数据 + memcpy(rtp_payload + kMP2VHeaderSize, ptr + offset, payload_size); + + // 输入到 RTP 环形缓存 + RtpCodec::inputRtp(rtp, is_key && is_first); + + offset += payload_size; + } + + return true; +} + +} // namespace mediakit diff --git a/ext-codec/MP2VRtp.h b/ext-codec/MP2VRtp.h new file mode 100644 index 00000000..8fd98581 --- /dev/null +++ b/ext-codec/MP2VRtp.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_MP2VRTP_H +#define ZLMEDIAKIT_MP2VRTP_H + +#include "MP2V.h" +#include "Common/Stamp.h" +#include "Rtsp/RtpCodec.h" + +namespace mediakit { + +// RFC 2250 MPEG Video-specific header (4 bytes) +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | MBZ |T| TR |N|S|B|E| P | | BFC | | FFC | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// AN FBV FFV + +static constexpr size_t kMP2VHeaderSize = 4; + +/** + * MP2V (MPEG-2 Video) RTP 解码器 + * 将 MPEG-2 Video over RTP 解复用出 MP2V Frame + * RFC 2250 + */ +class MP2VRtpDecoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + MP2VRtpDecoder(); + + /** + * 输入 MPEG-2 Video RTP 包 + * @param rtp rtp包 + * @param key_pos 此参数忽略之 + */ + bool inputRtp(const RtpPacket::Ptr &rtp, bool key_pos = true) override; + +private: + bool decodeRtp(const RtpPacket::Ptr &rtp); + void outputFrame(const RtpPacket::Ptr &rtp); + void obtainFrame(); + +private: + bool _gop_dropped = true; + bool _drop_flag = false; + uint16_t _last_seq = 0; + uint8_t _picture_type = 0; + MP2VFrame::Ptr _frame; + DtsGenerator _dts_generator; +}; + +/** + * MP2V (MPEG-2 Video) RTP 编码器 + * 将 MPEG-2 Video 帧打包为 RTP + * RFC 2250 + */ +class MP2VRtpEncoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + /** + * 输入 MPEG-2 Video 帧 + * @param frame 帧数据 + */ + bool inputFrame(const Frame::Ptr &frame) override; + +private: + /** + * 构建 RFC 2250 MPEG Video-specific header + * @param buf 输出缓冲区,至少4字节 + * @param data MPEG-2 ES 数据 + * @param size 数据大小 + * @param is_begin_of_slice 是否为 slice 起始 + * @param is_end_of_slice 是否为 slice 结束 + */ + void buildMpvHeader(uint8_t *buf, const uint8_t *data, size_t size, + bool is_begin_of_slice, bool is_end_of_slice); + + /** + * 解析当前帧信息(picture type, temporal reference 等) + */ + void parsePictureInfo(const uint8_t *data, size_t size); + + /** + * 查找 sequence header 是否存在 + */ + bool hasSequenceHeader(const uint8_t *data, size_t size); + +private: + uint16_t _temporal_ref = 0; + uint8_t _picture_type = 0; + uint8_t _fbv = 0; + uint8_t _bfc = 0; + uint8_t _ffv = 0; + uint8_t _ffc = 0; + bool _has_seq_header = false; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_MP2VRTP_H diff --git a/ext-codec/Opus.cpp b/ext-codec/Opus.cpp index 4ec57879..271bd926 100644 --- a/ext-codec/Opus.cpp +++ b/ext-codec/Opus.cpp @@ -11,16 +11,32 @@ #include "Opus.h" #include "Extension/Factory.h" #include "Extension/CommonRtp.h" -#include "Extension/CommonRtmp.h" - +#include "OpusRtmp.h" +#include "opus-head.h" using namespace std; using namespace toolkit; namespace mediakit { +void OpusTrack::setExtraData(const uint8_t *data, size_t size) { + opus_head_t header; + if (opus_head_load(data, size, &header) > 0) { + // Successfully parsed Opus header + _sample_rate = header.input_sample_rate; + _channels = header.channels; + } +} -Sdp::Ptr OpusTrack::getSdp(uint8_t payload_type) const { - return std::make_shared(payload_type, *this); +Buffer::Ptr OpusTrack::getExtraData() const { + struct opus_head_t opus {}; + opus.version = 1; + opus.channels = getAudioChannel(); + opus.input_sample_rate = getAudioSampleRate(); + // opus.pre_skip = 120; + opus.channel_mapping_family = 0; + auto ret = BufferRaw::create(29); + ret->setSize(opus_head_save(&opus, (uint8_t *)ret->data(), ret->getCapacity())); + return ret; } namespace { @@ -46,11 +62,11 @@ RtpCodec::Ptr getRtpDecoderByCodecId() { } RtmpCodec::Ptr getRtmpEncoderByTrack(const Track::Ptr &track) { - return std::make_shared(track); + return std::make_shared(track); } RtmpCodec::Ptr getRtmpDecoderByTrack(const Track::Ptr &track) { - return std::make_shared(track); + return std::make_shared(track); } Frame::Ptr getFrameFromPtr(const char *data, size_t bytes, uint64_t dts, uint64_t pts) { diff --git a/ext-codec/Opus.h b/ext-codec/Opus.h index a626b2f9..79a449a6 100644 --- a/ext-codec/Opus.h +++ b/ext-codec/Opus.h @@ -19,23 +19,20 @@ namespace mediakit { /** * Opus帧音频通道 * Opus frame audio channel - - * [AUTO-TRANSLATED:522e95da] */ -class OpusTrack : public AudioTrackImp{ +class OpusTrack : public AudioTrackImp { public: using Ptr = std::shared_ptr; OpusTrack() : AudioTrackImp(CodecOpus,48000,2,16){} private: - // 克隆该Track [AUTO-TRANSLATED:9a15682a] // Clone this Track Track::Ptr clone() const override { return std::make_shared(*this); } - // 生成sdp [AUTO-TRANSLATED:663a9367] - // Generate sdp - Sdp::Ptr getSdp(uint8_t payload_type) const override ; + + toolkit::Buffer::Ptr getExtraData() const override; + void setExtraData(const uint8_t *data, size_t size) override; }; }//namespace mediakit diff --git a/ext-codec/OpusRtmp.cpp b/ext-codec/OpusRtmp.cpp new file mode 100644 index 00000000..b7325363 --- /dev/null +++ b/ext-codec/OpusRtmp.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "OpusRtmp.h" +#include "Rtmp/utils.h" +#include "Common/config.h" +#include "Extension/Factory.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +void OpusRtmpDecoder::inputRtmp(const RtmpPacket::Ptr &pkt) { + auto data = pkt->data(); + int size = pkt->size(); + auto flags = (uint8_t)data[0]; + auto codec = (RtmpAudioCodec)(flags >> 4); + auto type = flags & 0x0F; + data++; size--; + if (codec == RtmpAudioCodec::ex_header) { + // @todo parse enhance audio header and check fourcc + data += 4; + size -= 4; + if (type == (uint8_t)RtmpPacketType::PacketTypeSequenceStart) { + getTrack()->setExtraData((uint8_t *)data, size); + } else { + outputFrame(data, size, pkt->time_stamp, pkt->time_stamp); + } + } else { + if (codec == RtmpAudioCodec::aac) { + uint8_t pkt_type = *data; + data++; size--; + if (pkt_type == (uint8_t)RtmpAACPacketType::aac_config_header) { + getTrack()->setExtraData((uint8_t *)data, size); + return; + } + } + outputFrame(data, size, pkt->time_stamp, pkt->time_stamp); + } +} + +void OpusRtmpDecoder::outputFrame(const char *data, size_t size, uint32_t dts, uint32_t pts) { + RtmpCodec::inputFrame(Factory::getFrameFromPtr(getTrack()->getCodecId(), data, size, dts, pts)); +} + +//////////////////////////////////////////////////////////////////////// +OpusRtmpEncoder::OpusRtmpEncoder(const Track::Ptr &track) : RtmpCodec(track) { + _enhanced = mINI::Instance()[Rtmp::kEnhanced]; +} + +bool OpusRtmpEncoder::inputFrame(const Frame::Ptr &frame) { + auto packet = RtmpPacket::create(); + if (_enhanced) { + uint8_t flags = ((uint8_t)RtmpAudioCodec::ex_header << 4) | (uint8_t)RtmpPacketType::PacketTypeCodedFrames; + packet->buffer.push_back(flags); + uint32_t fourcc = htonl(getCodecFourCC(getTrack()->getCodecId())); + packet->buffer.append(reinterpret_cast(&fourcc), 4); + } else { + uint8_t flags = getAudioRtmpFlags(getTrack()); + packet->buffer.push_back(flags); + if (getTrack()->getCodecId() == CodecAAC) { + packet->buffer.push_back((uint8_t)RtmpAACPacketType::aac_raw); + } + } + packet->buffer.append(frame->data(), frame->size()); + packet->body_size = packet->buffer.size(); + packet->time_stamp = frame->dts(); + packet->chunk_id = CHUNK_AUDIO; + packet->stream_index = STREAM_MEDIA; + packet->type_id = MSG_AUDIO; + // Output rtmp packet + RtmpCodec::inputRtmp(packet); + return true; +} + +void OpusRtmpEncoder::makeConfigPacket() { + auto extra_data = getTrack()->getExtraData(); + if (!extra_data || !extra_data->size()) + return; + auto packet = RtmpPacket::create(); + if (_enhanced) { + uint8_t flags = ((uint8_t)RtmpAudioCodec::ex_header << 4) | (uint8_t)RtmpPacketType::PacketTypeSequenceStart; + packet->buffer.push_back(flags); + uint32_t fourcc = htonl(getCodecFourCC(getTrack()->getCodecId())); + packet->buffer.append(reinterpret_cast(&fourcc), 4); + } else { + uint8_t flags = getAudioRtmpFlags(getTrack()); + packet->buffer.push_back(flags); + if (getTrack()->getCodecId() == CodecAAC) { + packet->buffer.push_back((uint8_t)RtmpAACPacketType::aac_config_header); + } + else{ + return ; + } + } + packet->buffer.append(extra_data->data(), extra_data->size()); + packet->body_size = packet->buffer.size(); + packet->chunk_id = CHUNK_AUDIO; + packet->stream_index = STREAM_MEDIA; + packet->time_stamp = 0; + packet->type_id = MSG_AUDIO; + RtmpCodec::inputRtmp(packet); +} + +} // namespace mediakit diff --git a/ext-codec/OpusRtmp.h b/ext-codec/OpusRtmp.h new file mode 100644 index 00000000..361a60f5 --- /dev/null +++ b/ext-codec/OpusRtmp.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_OPUS_RTMPCODEC_H +#define ZLMEDIAKIT_OPUS_RTMPCODEC_H + +#include "Rtmp/RtmpCodec.h" +#include "Extension/Track.h" + +namespace mediakit { +/** + * Rtmp解码类 + * 将 Opus over rtmp 解复用出 OpusFrame + */ +class OpusRtmpDecoder : public RtmpCodec { +public: + using Ptr = std::shared_ptr; + + OpusRtmpDecoder(const Track::Ptr &track) : RtmpCodec(track) {} + + void inputRtmp(const RtmpPacket::Ptr &rtmp) override; + +protected: + void outputFrame(const char *data, size_t size, uint32_t dts, uint32_t pts); +}; + +/** + * Rtmp打包类 + */ +class OpusRtmpEncoder : public RtmpCodec { + bool _enhanced = false; +public: + using Ptr = std::shared_ptr; + + OpusRtmpEncoder(const Track::Ptr &track); + + bool inputFrame(const Frame::Ptr &frame) override; + + void makeConfigPacket() override; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_OPUS_RTMPCODEC_H diff --git a/ext-codec/VP8.cpp b/ext-codec/VP8.cpp new file mode 100644 index 00000000..f01d3b59 --- /dev/null +++ b/ext-codec/VP8.cpp @@ -0,0 +1,79 @@ +#include "VP8.h" +#include "VP8Rtp.h" +#include "VpxRtmp.h" +#include "Extension/Factory.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +bool VP8Track::inputFrame(const Frame::Ptr &frame) { + char *dataPtr = frame->data() + frame->prefixSize(); + if (frame->keyFrame()) { + if (frame->size() - frame->prefixSize() < 10) + return false; + _width = ((dataPtr[7] << 8) + dataPtr[6]) & 0x3FFF; + _height = ((dataPtr[9] << 8) + dataPtr[8]) & 0x3FFF; + webm_vpx_codec_configuration_record_from_vp8(&_vpx, &_width, &_height, dataPtr, frame->size() - frame->prefixSize()); + // InfoL << _width << "x" << _height; + } + return VideoTrackImp::inputFrame(frame); +} + +Buffer::Ptr VP8Track::getExtraData() const { + auto ret = BufferRaw::create(8 + _vpx.codec_intialization_data_size); + ret->setSize(webm_vpx_codec_configuration_record_save(&_vpx, (uint8_t *)ret->data(), ret->getCapacity())); + return ret; +} + +void VP8Track::setExtraData(const uint8_t *data, size_t size) { + webm_vpx_codec_configuration_record_load(data, size, &_vpx); +} + +namespace { + +CodecId getCodec() { + return CodecVP8; +} + +Track::Ptr getTrackByCodecId(int sample_rate, int channels, int sample_bit) { + return std::make_shared(); +} + +Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpDecoderByCodecId() { + return std::make_shared(); +} + +RtmpCodec::Ptr getRtmpEncoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +RtmpCodec::Ptr getRtmpDecoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +Frame::Ptr getFrameFromPtr(const char *data, size_t bytes, uint64_t dts, uint64_t pts) { + return std::make_shared((char *)data, bytes, dts, pts, 0); +} + +} // namespace + +CodecPlugin vp8_plugin = { getCodec, + getTrackByCodecId, + getTrackBySdp, + getRtpEncoderByCodecId, + getRtpDecoderByCodecId, + getRtmpEncoderByTrack, + getRtmpDecoderByTrack, + getFrameFromPtr }; + +} // namespace mediakit \ No newline at end of file diff --git a/ext-codec/VP8.h b/ext-codec/VP8.h new file mode 100644 index 00000000..10009456 --- /dev/null +++ b/ext-codec/VP8.h @@ -0,0 +1,49 @@ +#ifndef ZLMEDIAKIT_VP8_H +#define ZLMEDIAKIT_VP8_H + +#include "Extension/Frame.h" +#include "Extension/Track.h" +#include "webm-vpx.h" +namespace mediakit { +template +class VP8FrameHelper : public Parent { +public: + friend class FrameImp; + //friend class toolkit::ResourcePool_l; + using Ptr = std::shared_ptr; + + template + VP8FrameHelper(ARGS &&...args) + : Parent(std::forward(args)...) { + this->_codec_id = CodecVP8; + } + + bool keyFrame() const override { + auto ptr = (uint8_t *) this->data() + this->prefixSize(); + return !(*ptr & 0x01); + } + bool configFrame() const override { return false; } + bool dropAble() const override { return false; } + bool decodeAble() const override { return true; } +}; + +/// VP8 帧类 +using VP8Frame = VP8FrameHelper; +using VP8FrameNoCacheAble = VP8FrameHelper; + +class VP8Track : public VideoTrackImp { +public: + VP8Track() : VideoTrackImp(CodecVP8) {} + + Track::Ptr clone() const override { return std::make_shared(*this); } + + bool inputFrame(const Frame::Ptr &frame) override; + toolkit::Buffer::Ptr getExtraData() const override; + void setExtraData(const uint8_t *data, size_t size) override; +private: + webm_vpx_t _vpx {}; +}; + +} // namespace mediakit + +#endif \ No newline at end of file diff --git a/ext-codec/VP8Rtp.cpp b/ext-codec/VP8Rtp.cpp new file mode 100644 index 00000000..3f0760ae --- /dev/null +++ b/ext-codec/VP8Rtp.cpp @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "VP8Rtp.h" +#include "Extension/Frame.h" +#include "Common/config.h" + +namespace mediakit{ + +const int16_t kNoPictureId = -1; +const int8_t kNoTl0PicIdx = -1; +const uint8_t kNoTemporalIdx = 0xFF; +const int kNoKeyIdx = -1; + +// internal bits +constexpr int kXBit = 0x80; +constexpr int kNBit = 0x20; +constexpr int kSBit = 0x10; +constexpr int kKeyIdxField = 0x1F; +constexpr int kIBit = 0x80; +constexpr int kLBit = 0x40; +constexpr int kTBit = 0x20; +constexpr int kKBit = 0x10; +constexpr int kYBit = 0x20; +constexpr int kFailedToParse = 0; +// VP8 payload descriptor +// https://datatracker.ietf.org/doc/html/rfc7741#section-4.2 +// +// 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+ +// |X|R|N|S|R| PID | (REQUIRED) +// +-+-+-+-+-+-+-+-+ +// X: |I|L|T|K| RSV | (OPTIONAL) +// +-+-+-+-+-+-+-+-+ +// I: |M| PictureID | (OPTIONAL) +// +-+-+-+-+-+-+-+-+ +// | PictureID | +// +-+-+-+-+-+-+-+-+ +// L: | TL0PICIDX | (OPTIONAL) +// +-+-+-+-+-+-+-+-+ +// T/K: |TID|Y| KEYIDX | (OPTIONAL) +// +-+-+-+-+-+-+-+-+ +struct RTPVideoHeaderVP8 { + void InitRTPVideoHeaderVP8(); + + int Size() const; + int Write(uint8_t *data, int size) const; + int Read(const uint8_t *data, int data_length); + bool isFirstPacket() const { return beginningOfPartition && partitionId == 0; } + friend bool operator!=(const RTPVideoHeaderVP8 &lhs, const RTPVideoHeaderVP8 &rhs) { return !(lhs == rhs); } + friend bool operator==(const RTPVideoHeaderVP8 &lhs, const RTPVideoHeaderVP8 &rhs) { + return lhs.nonReference == rhs.nonReference && lhs.pictureId == rhs.pictureId && lhs.tl0PicIdx == rhs.tl0PicIdx && lhs.temporalIdx == rhs.temporalIdx + && lhs.layerSync == rhs.layerSync && lhs.keyIdx == rhs.keyIdx && lhs.partitionId == rhs.partitionId + && lhs.beginningOfPartition == rhs.beginningOfPartition; + } + + bool nonReference; // Frame is discardable. + int16_t pictureId; // Picture ID index, 15 bits; + // kNoPictureId if PictureID does not exist. + int8_t tl0PicIdx; // TL0PIC_IDX, 8 bits; + // kNoTl0PicIdx means no value provided. + uint8_t temporalIdx; // Temporal layer index, or kNoTemporalIdx. + bool layerSync; // This frame is a layer sync frame. + // Disabled if temporalIdx == kNoTemporalIdx. + int8_t keyIdx; // 5 bits; kNoKeyIdx means not used. + int8_t partitionId; // VP8 partition ID + bool beginningOfPartition; // True if this packet is the first + // in a VP8 partition. Otherwise false +}; + +void RTPVideoHeaderVP8::InitRTPVideoHeaderVP8() { + nonReference = false; + pictureId = kNoPictureId; + tl0PicIdx = kNoTl0PicIdx; + temporalIdx = kNoTemporalIdx; + layerSync = false; + keyIdx = kNoKeyIdx; + partitionId = 0; + beginningOfPartition = false; +} + +int RTPVideoHeaderVP8::Size() const { + bool tid_present = this->temporalIdx != kNoTemporalIdx; + bool keyid_present = this->keyIdx != kNoKeyIdx; + bool tl0_pid_present = this->tl0PicIdx != kNoTl0PicIdx; + bool pid_present = this->pictureId != kNoPictureId; + int ret = 2; + if (pid_present) + ret += 2; + if (tl0_pid_present) + ret++; + if (tid_present || keyid_present) + ret++; + return ret == 2 ? 1 : ret; +} + +int RTPVideoHeaderVP8::Write(uint8_t *data, int size) const { + int ret = 0; + bool tid_present = this->temporalIdx != kNoTemporalIdx; + bool keyid_present = this->keyIdx != kNoKeyIdx; + bool tl0_pid_present = this->tl0PicIdx != kNoTl0PicIdx; + bool pid_present = this->pictureId != kNoPictureId; + uint8_t x_field = 0; + if (pid_present) + x_field |= kIBit; + if (tl0_pid_present) + x_field |= kLBit; + if (tid_present) + x_field |= kTBit; + if (keyid_present) + x_field |= kKBit; + + uint8_t flags = 0; + if (x_field != 0) + flags |= kXBit; + if (this->nonReference) + flags |= kNBit; + // Create header as first packet in the frame. NextPacket() will clear it + // after first use. + flags |= kSBit; + data[ret++] = flags; + if (x_field == 0) { + return ret; + } + data[ret++] = x_field; + if (pid_present) { + const uint16_t pic_id = static_cast(this->pictureId); + data[ret++] = (0x80 | ((pic_id >> 8) & 0x7F)); + data[ret++] = (pic_id & 0xFF); + } + if (tl0_pid_present) { + data[ret++] = this->tl0PicIdx; + } + if (tid_present || keyid_present) { + uint8_t data_field = 0; + if (tid_present) { + data_field |= this->temporalIdx << 6; + if (this->layerSync) + data_field |= kYBit; + } + if (keyid_present) { + data_field |= (this->keyIdx & kKeyIdxField); + } + data[ret++] = data_field; + } + return ret; +} + +int RTPVideoHeaderVP8::Read(const uint8_t *data, int data_length) { + // RTC_DCHECK_GT(data_length, 0); + int parsed_bytes = 0; + // Parse mandatory first byte of payload descriptor. + bool extension = (*data & 0x80) ? true : false; // X bit + this->nonReference = (*data & 0x20) ? true : false; // N bit + this->beginningOfPartition = (*data & 0x10) ? true : false; // S bit + this->partitionId = (*data & 0x07); // PID field + + data++; + parsed_bytes++; + data_length--; + + if (!extension) + return parsed_bytes; + + if (data_length == 0) + return kFailedToParse; + // Optional X field is present. + bool has_picture_id = (*data & 0x80) ? true : false; // I bit + bool has_tl0_pic_idx = (*data & 0x40) ? true : false; // L bit + bool has_tid = (*data & 0x20) ? true : false; // T bit + bool has_key_idx = (*data & 0x10) ? true : false; // K bit + + // Advance data and decrease remaining payload size. + data++; + parsed_bytes++; + data_length--; + + if (has_picture_id) { + if (data_length == 0) + return kFailedToParse; + + this->pictureId = (*data & 0x7F); + if (*data & 0x80) { + data++; + parsed_bytes++; + if (--data_length == 0) + return kFailedToParse; + // PictureId is 15 bits + this->pictureId = (this->pictureId << 8) + *data; + } + data++; + parsed_bytes++; + data_length--; + } + + if (has_tl0_pic_idx) { + if (data_length == 0) + return kFailedToParse; + + this->tl0PicIdx = *data; + data++; + parsed_bytes++; + data_length--; + } + + if (has_tid || has_key_idx) { + if (data_length == 0) + return kFailedToParse; + + if (has_tid) { + this->temporalIdx = ((*data >> 6) & 0x03); + this->layerSync = (*data & 0x20) ? true : false; // Y bit + } + if (has_key_idx) { + this->keyIdx = *data & 0x1F; + } + data++; + parsed_bytes++; + data_length--; + } + return parsed_bytes; +} + +///////////////////////////////////////////////// +// VP8RtpDecoder +VP8RtpDecoder::VP8RtpDecoder() { + obtainFrame(); +} + +void VP8RtpDecoder::obtainFrame() { + _frame = FrameImp::create(); +} + +bool VP8RtpDecoder::inputRtp(const RtpPacket::Ptr &rtp, bool key_pos) { + auto seq = rtp->getSeq(); + bool ret = decodeRtp(rtp); + if (!_gop_dropped && seq != (uint16_t)(_last_seq + 1) && _last_seq) { + _gop_dropped = true; + WarnL << "start drop vp8 gop, last seq:" << _last_seq << ", rtp:\r\n" << rtp->dumpString(); + } + _last_seq = seq; + return ret; +} + +bool VP8RtpDecoder::decodeRtp(const RtpPacket::Ptr &rtp) { + auto payload_size = rtp->getPayloadSize(); + if (payload_size <= 0) { + // No actual payload + return false; + } + auto payload = rtp->getPayload(); + auto stamp = rtp->getStampMS(); + auto seq = rtp->getSeq(); + + RTPVideoHeaderVP8 info; + int offset = info.Read(payload, payload_size); + if (!offset) { + //_frame_drop = true; + return false; + } + bool start = info.isFirstPacket(); + if (start) { + _frame->_pts = stamp; + _frame->_buffer.clear(); + _frame_drop = false; + } + + if (_frame_drop) { + // This frame is incomplete + return false; + } + + if (!start && seq != (uint16_t)(_last_seq + 1)) { + // 中间的或末尾的rtp包,其seq必须连续,否则说明rtp丢包,那么该帧不完整,必须得丢弃 + _frame_drop = true; + _frame->_buffer.clear(); + return false; + } + // Append data + _frame->_buffer.append((char *)payload + offset, payload_size - offset); + bool end = rtp->getHeader()->mark; + if (end) { + // 确保下一次fu必须收到第一个包 + _frame_drop = true; + // 该帧最后一个rtp包,输出frame [AUTO-TRANSLATED:a648aaa5] + // The last rtp packet of this frame, output frame + outputFrame(rtp); + } + + return (info.isFirstPacket() && (payload[offset] & 0x01) == 0); +} + +void VP8RtpDecoder::outputFrame(const RtpPacket::Ptr &rtp) { + if (_frame->dropAble()) { + // 不参与dts生成 [AUTO-TRANSLATED:dff3b747] + // Not involved in dts generation + _frame->_dts = _frame->_pts; + } else { + // rtsp没有dts,那么根据pts排序算法生成dts [AUTO-TRANSLATED:f37c17f3] + // Rtsp does not have dts, so dts is generated according to the pts sorting algorithm + _dts_generator.getDts(_frame->_pts, _frame->_dts); + } + + if (_frame->keyFrame() && _gop_dropped) { + _gop_dropped = false; + InfoL << "new gop received, rtp:\r\n" << rtp->dumpString(); + } + if (!_gop_dropped || _frame->configFrame()) { + RtpCodec::inputFrame(_frame); + } + obtainFrame(); +} + +//////////////////////////////////////////////////////////////////////// + +bool VP8RtpEncoder::inputFrame(const Frame::Ptr &frame) { + RTPVideoHeaderVP8 info; + info.InitRTPVideoHeaderVP8(); + info.beginningOfPartition = true; + info.nonReference = !frame->dropAble(); + uint8_t header[20]; + int header_size = info.Write(header, sizeof(header)); + + int pdu_size = getRtpInfo().getMaxSize() - header_size; + const char *ptr = frame->data() + frame->prefixSize(); + size_t len = frame->size() - frame->prefixSize(); + bool key = frame->keyFrame(); + bool mark = false; + for (size_t pos = 0; pos < len; pos += pdu_size) { + if (static_cast(len - pos) <= pdu_size) { + pdu_size = len - pos; + mark = true; + } + + auto rtp = getRtpInfo().makeRtp(TrackVideo, nullptr, pdu_size + header_size, mark, frame->pts()); + if (rtp) { + uint8_t *payload = rtp->getPayload(); + memcpy(payload, header, header_size); + memcpy(payload + header_size, ptr + pos, pdu_size); + RtpCodec::inputRtp(rtp, key); + } + + key = false; + header[0] &= (~kSBit); // Clear 'Start of partition' bit. + } + return true; +} + +} // namespace mediakit diff --git a/ext-codec/VP8Rtp.h b/ext-codec/VP8Rtp.h new file mode 100644 index 00000000..5b4dd0db --- /dev/null +++ b/ext-codec/VP8Rtp.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_VP8RTPCODEC_H +#define ZLMEDIAKIT_VP8RTPCODEC_H + +#include "VP8.h" +// for DtsGenerator +#include "Common/Stamp.h" +#include "Rtsp/RtpCodec.h" + +namespace mediakit { + +/** + * vp8 rtp解码类 + * 将 vp8 over rtsp-rtp 解复用出 VP8Frame + */ +class VP8RtpDecoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + VP8RtpDecoder(); + + /** + * 输入vp8 rtp包 + * @param rtp rtp包 + * @param key_pos 此参数忽略之 + */ + bool inputRtp(const RtpPacket::Ptr &rtp, bool key_pos = true) override; + +private: + bool decodeRtp(const RtpPacket::Ptr &rtp); + void outputFrame(const RtpPacket::Ptr &rtp); + void obtainFrame(); + +private: + bool _gop_dropped = false; + bool _frame_drop = true; + uint16_t _last_seq = 0; + VP8Frame::Ptr _frame; + DtsGenerator _dts_generator; +}; + +/** + * vp8 rtp打包类 + */ +class VP8RtpEncoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + bool inputFrame(const Frame::Ptr &frame) override; +}; + +}//namespace mediakit + +#endif //ZLMEDIAKIT_VP8RTPCODEC_H diff --git a/ext-codec/VP9.cpp b/ext-codec/VP9.cpp new file mode 100644 index 00000000..b5873180 --- /dev/null +++ b/ext-codec/VP9.cpp @@ -0,0 +1,76 @@ +#include "VP9.h" +#include "VP9Rtp.h" +#include "VpxRtmp.h" +#include "Extension/Factory.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +bool VP9Track::inputFrame(const Frame::Ptr &frame) { + char *dataPtr = frame->data() + frame->prefixSize(); + if (frame->keyFrame()) { + if (frame->size() - frame->prefixSize() < 10) + return false; + webm_vpx_codec_configuration_record_from_vp9(&_vpx, &_width, &_height, dataPtr, frame->size() - frame->prefixSize()); + } + return VideoTrackImp::inputFrame(frame); +} + +Buffer::Ptr VP9Track::getExtraData() const { + auto ret = BufferRaw::create(8 + _vpx.codec_intialization_data_size); + ret->setSize(webm_vpx_codec_configuration_record_save(&_vpx, (uint8_t *)ret->data(), ret->getCapacity())); + return ret; +} + +void VP9Track::setExtraData(const uint8_t *data, size_t size) { + webm_vpx_codec_configuration_record_load(data, size, &_vpx); +} + +namespace { + +CodecId getCodec() { + return CodecVP9; +} + +Track::Ptr getTrackByCodecId(int sample_rate, int channels, int sample_bit) { + return std::make_shared(); +} + +Track::Ptr getTrackBySdp(const SdpTrack::Ptr &track) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpEncoderByCodecId(uint8_t pt) { + return std::make_shared(); +} + +RtpCodec::Ptr getRtpDecoderByCodecId() { + return std::make_shared(); +} + +RtmpCodec::Ptr getRtmpEncoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +RtmpCodec::Ptr getRtmpDecoderByTrack(const Track::Ptr &track) { + return std::make_shared(track); +} + +Frame::Ptr getFrameFromPtr(const char *data, size_t bytes, uint64_t dts, uint64_t pts) { + return std::make_shared((char *)data, bytes, dts, pts, 0); +} + +} // namespace + +CodecPlugin vp9_plugin = { getCodec, + getTrackByCodecId, + getTrackBySdp, + getRtpEncoderByCodecId, + getRtpDecoderByCodecId, + getRtmpEncoderByTrack, + getRtmpDecoderByTrack, + getFrameFromPtr }; + +} // namespace mediakit \ No newline at end of file diff --git a/ext-codec/VP9.h b/ext-codec/VP9.h new file mode 100644 index 00000000..c99cd70f --- /dev/null +++ b/ext-codec/VP9.h @@ -0,0 +1,49 @@ +#ifndef ZLMEDIAKIT_VP9_H +#define ZLMEDIAKIT_VP9_H + +#include "Extension/Frame.h" +#include "Extension/Track.h" +#include "webm-vpx.h" +namespace mediakit { +template +class VP9FrameHelper : public Parent { +public: + friend class FrameImp; + //friend class toolkit::ResourcePool_l; + using Ptr = std::shared_ptr; + + template + VP9FrameHelper(ARGS &&...args) + : Parent(std::forward(args)...) { + this->_codec_id = CodecVP9; + } + + bool keyFrame() const override { + auto ptr = (uint8_t *) this->data() + this->prefixSize(); + return (*ptr & 0x80); + } + bool configFrame() const override { return false; } + bool dropAble() const override { return false; } + bool decodeAble() const override { return true; } +}; + +/// VP9 帧类 +using VP9Frame = VP9FrameHelper; +using VP9FrameNoCacheAble = VP9FrameHelper; + +class VP9Track : public VideoTrackImp { +public: + VP9Track() : VideoTrackImp(CodecVP9) {}; + + Track::Ptr clone() const override { return std::make_shared(*this); } + + bool inputFrame(const Frame::Ptr &frame) override; + toolkit::Buffer::Ptr getExtraData() const override; + void setExtraData(const uint8_t *data, size_t size) override; +private: + webm_vpx_t _vpx {}; +}; + +} // namespace mediakit + +#endif \ No newline at end of file diff --git a/ext-codec/VP9Rtp.cpp b/ext-codec/VP9Rtp.cpp new file mode 100644 index 00000000..33445e87 --- /dev/null +++ b/ext-codec/VP9Rtp.cpp @@ -0,0 +1,342 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "VP9Rtp.h" +#include "Extension/Frame.h" +#include "Common/config.h" + +namespace mediakit{ + +const int16_t kNoPictureId = -1; +const int8_t kNoTl0PicIdx = -1; +const uint8_t kNoTemporalIdx = 0xFF; +const int kNoKeyIdx = -1; + +struct VP9ResolutionLayer { + int width; + int height; +}; + +struct RTPPayloadVP9 { + bool hasPictureID = false; + bool interPicturePrediction = false; + bool hasLayerIndices = false; + bool flexibleMode = false; + bool beginningOfLayerFrame = false; + bool endingOfLayerFrame = false; + bool hasScalabilityStructure = false; + bool largePictureID = false; + int pictureID = -1; + int temporalID = -1; + bool isSwitchingUp = false; + int spatialID = -1; + bool isInterLayeredDepUsed = false; + int tl0PicIdx = -1; + int referenceIdx = -1; + bool additionalReferenceIdx = false; + int spatialLayers = -1; + bool hasResolution = false; + bool hasGof = false; + int numberOfFramesInGof = -1; + std::vector resolutions; + int parse(unsigned char* data, int dataLength); + bool keyFrame() const { return beginningOfLayerFrame && !interPicturePrediction; } + std::string dump() const { + char line[64] = {0}; + snprintf(line, sizeof(line), "%c%c%c%c%c%c%c- %d %d, %d %d", + hasPictureID ? 'I' : ' ', + interPicturePrediction ? 'P' : ' ', + hasLayerIndices ? 'L' : ' ', + flexibleMode ? 'F' : ' ', + beginningOfLayerFrame ? 'B' : ' ', + endingOfLayerFrame ? 'E' : ' ', + hasScalabilityStructure ? 'V' : ' ', + pictureID, tl0PicIdx, + spatialID, temporalID); + return line; + } +}; +// +// VP9 format: +// +// Payload descriptor (Flexible mode F = 1) +// 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+ +// |I|P|L|F|B|E|V|-| (REQUIRED) +// +-+-+-+-+-+-+-+-+ +// I: |M| PICTURE ID | (REQUIRED) +// +-+-+-+-+-+-+-+-+ +// M: | EXTENDED PID | (RECOMMENDED) +// +-+-+-+-+-+-+-+-+ +// L: | T |U| S |D| (CONDITIONALLY RECOMMENDED) +// +-+-+-+-+-+-+-+-+ - +// P,F: | P_DIFF |N| (CONDITIONALLY REQUIRED) - up to 3 times +// +-+-+-+-+-+-+-+-+ - +// V: | SS | +// | .. | +// +-+-+-+-+-+-+-+-+ +// +// Payload descriptor (Non flexible mode F = 0) +// +// 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+ +// |I|P|L|F|B|E|V|-| (REQUIRED) +// +-+-+-+-+-+-+-+-+ +// I: |M| PICTURE ID | (RECOMMENDED) +// +-+-+-+-+-+-+-+-+ +// M: | EXTENDED PID | (RECOMMENDED) +// +-+-+-+-+-+-+-+-+ +// L: | T |U| S |D| (CONDITIONALLY RECOMMENDED) +// +-+-+-+-+-+-+-+-+ +// | TL0PICIDX | (CONDITIONALLY REQUIRED) +// +-+-+-+-+-+-+-+-+ +// V: | SS | +// | .. | +// +-+-+-+-+-+-+-+-+ +#define kIBit 0x80 +#define kPBit 0x40 +#define kLBit 0x20 +#define kFBit 0x10 +#define kBBit 0x08 +#define kEBit 0x04 +#define kVBit 0x02 +int RTPPayloadVP9::parse(unsigned char *data, int dataLength) { + const unsigned char* dataPtr = data; + const unsigned char* dataEnd = data + dataLength; + +#define VP9_CHECK_BOUNDS(n) do { if (dataPtr + (n) > dataEnd) return -1; } while (0) + + // Parse mandatory first byte of payload descriptor + VP9_CHECK_BOUNDS(1); + this->hasPictureID = (*dataPtr & kIBit); // I bit + this->interPicturePrediction = (*dataPtr & kPBit); // P bit + this->hasLayerIndices = (*dataPtr & kLBit); // L bit + this->flexibleMode = (*dataPtr & kFBit); // F bit + this->beginningOfLayerFrame = (*dataPtr & kBBit); // B bit + this->endingOfLayerFrame = (*dataPtr & kEBit); // E bit + this->hasScalabilityStructure = (*dataPtr & kVBit); // V bit + dataPtr++; + + if (this->hasPictureID) { + VP9_CHECK_BOUNDS(1); + this->largePictureID = (*dataPtr & 0x80); // M bit + this->pictureID = (*dataPtr & 0x7F); + if (this->largePictureID) { + dataPtr++; + VP9_CHECK_BOUNDS(1); + this->pictureID = ntohs((this->pictureID << 16) + (*dataPtr & 0xFF)); + } + dataPtr++; + } + + if (this->hasLayerIndices) { + VP9_CHECK_BOUNDS(1); + this->temporalID = (*dataPtr & 0xE0) >> 5; // T bits + this->isSwitchingUp = (*dataPtr & 0x10); // U bit + this->spatialID = (*dataPtr & 0x0E) >> 1; // S bits + this->isInterLayeredDepUsed = (*dataPtr & 0x01); // D bit + if (this->flexibleMode) { // marked in webrtc code + do { + dataPtr++; + VP9_CHECK_BOUNDS(1); + this->referenceIdx = (*dataPtr & 0xFE) >> 1; + this->additionalReferenceIdx = (*dataPtr & 0x01); // D bit + } while (this->additionalReferenceIdx); + } else { + dataPtr++; + VP9_CHECK_BOUNDS(1); + this->tl0PicIdx = (*dataPtr & 0xFF); + } + dataPtr++; + } + + if (this->flexibleMode && this->interPicturePrediction) { + /* Skip reference indices */ + uint8_t nbit; + do { + VP9_CHECK_BOUNDS(1); + uint8_t p_diff = (*dataPtr & 0xFE) >> 1; + nbit = (*dataPtr & 0x01); + dataPtr++; + } while (nbit); + } + if (this->hasScalabilityStructure) { + VP9_CHECK_BOUNDS(1); + this->spatialLayers = (*dataPtr & 0xE0) >> 5; // N_S bits + this->hasResolution = (*dataPtr & 0x10); // Y bit + this->hasGof = (*dataPtr & 0x08); // G bit + dataPtr++; + if (this->hasResolution) { + for (int i = 0; i <= this->spatialLayers; i++) { + VP9_CHECK_BOUNDS(4); + int width = (dataPtr[0] << 8) + dataPtr[1]; + dataPtr += 2; + int height = (dataPtr[0] << 8) + dataPtr[1]; + dataPtr += 2; + // InfoL << "got vp9 " << width << "x" << height; + this->resolutions.push_back({ width, height }); + } + } + if (this->hasGof) { + VP9_CHECK_BOUNDS(1); + this->numberOfFramesInGof = *dataPtr & 0xFF; // N_G bits + dataPtr++; + for (int frame_index = 0; frame_index < this->numberOfFramesInGof; frame_index++) { + // TODO(javierc): Read these values if needed + VP9_CHECK_BOUNDS(1); + int reference_indices = (*dataPtr & 0x0C) >> 2; // R bits + dataPtr++; + VP9_CHECK_BOUNDS(reference_indices); + for (int reference_index = 0; reference_index < reference_indices; reference_index++) { + dataPtr++; + } + } + } + } + +#undef VP9_CHECK_BOUNDS + + return dataPtr - data; +} + + +//////////////////////////////////////////////////// +VP9RtpDecoder::VP9RtpDecoder() { + obtainFrame(); +} + +void VP9RtpDecoder::obtainFrame() { + _frame = FrameImp::create(); +} + +bool VP9RtpDecoder::inputRtp(const RtpPacket::Ptr &rtp, bool key_pos) { + auto seq = rtp->getSeq(); + bool is_gop = decodeRtp(rtp); + if (!_gop_dropped && seq != (uint16_t)(_last_seq + 1) && _last_seq) { + _gop_dropped = true; + WarnL << "start drop VP9 gop, last seq:" << _last_seq << ", rtp:\r\n" << rtp->dumpString(); + } + _last_seq = seq; + return is_gop; +} + +bool VP9RtpDecoder::decodeRtp(const RtpPacket::Ptr &rtp) { + auto payload_size = rtp->getPayloadSize(); + if (payload_size < 1) { + // No actual payload + return false; + } + auto payload = rtp->getPayload(); + auto stamp = rtp->getStampMS(); + auto seq = rtp->getSeq(); + + RTPPayloadVP9 info; + int offset = info.parse(payload, payload_size); + if (offset < 0) { + WarnL << "VP9 RTP payload parse failed, seq:" << seq; + return false; + } + // InfoL << rtp->dumpString() << "\n" << info.dump(); + bool start = info.beginningOfLayerFrame; + if (start) { + _frame->_pts = stamp; + _frame->_buffer.clear(); + _frame_drop = false; + } + + if (_frame_drop) { + // This frame is incomplete + return false; + } + + if (!start && seq != (uint16_t)(_last_seq + 1)) { + // 中间的或末尾的rtp包,其seq必须连续,否则说明rtp丢包,那么该帧不完整,必须得丢弃 + _frame_drop = true; + _frame->_buffer.clear(); + return false; + } + // Append data + _frame->_buffer.append((char *)payload + offset, payload_size - offset); + if (info.endingOfLayerFrame) { // rtp->getHeader()->mark + // 确保下一个包必须是beginningOfLayerFrame + _frame_drop = true; + // 该帧最后一个rtp包,输出frame + outputFrame(rtp); + } + return info.keyFrame(); +} + +void VP9RtpDecoder::outputFrame(const RtpPacket::Ptr &rtp) { + if (_frame->dropAble()) { + // 不参与dts生成 [AUTO-TRANSLATED:dff3b747] + // Not involved in dts generation + _frame->_dts = _frame->_pts; + } else { + // rtsp没有dts,那么根据pts排序算法生成dts [AUTO-TRANSLATED:f37c17f3] + // Rtsp does not have dts, so dts is generated according to the pts sorting algorithm + _dts_generator.getDts(_frame->_pts, _frame->_dts); + } + + if (_frame->keyFrame() && _gop_dropped) { + _gop_dropped = false; + InfoL << "new gop received, rtp:\r\n" << rtp->dumpString(); + } + if (!_gop_dropped || _frame->configFrame()) { + // InfoL << _frame->pts() << " size=" << _frame->size(); + RtpCodec::inputFrame(_frame); + } + obtainFrame(); +} + + +//////////////////////////////////////////////////////////////////////// + +bool VP9RtpEncoder::inputFrame(const Frame::Ptr &frame) { + uint8_t header[20] = { 0 }; + int nheader = 1; + header[0] = kBBit; + bool key = frame->keyFrame(); + if (!key) + header[0] |= kPBit; +#if 1 + header[0] |= kIBit; + if (++_pic_id > 0x7FFF) { + _pic_id = 0; + } + header[1] = (0x80 | ((_pic_id >> 8) & 0x7F)); + header[2] = (_pic_id & 0xFF); + nheader += 2; +#endif + const char *ptr = frame->data() + frame->prefixSize(); + int len = frame->size() - frame->prefixSize(); + int pdu_size = getRtpInfo().getMaxSize() - nheader; + + bool mark = false; + for (int pos = 0; pos < len; pos += pdu_size) { + if (len - pos <= pdu_size) { + pdu_size = len - pos; + header[0] |= kEBit; + mark = true; + } + + auto rtp = getRtpInfo().makeRtp(TrackVideo, nullptr, pdu_size + nheader, mark, frame->pts()); + if (rtp) { + uint8_t *payload = rtp->getPayload(); + memcpy(payload, header, nheader); + memcpy(payload + nheader, ptr + pos, pdu_size); + RtpCodec::inputRtp(rtp, key); + } + key = false; + header[0] &= (~kBBit); // Clear 'Begin of partition' bit. + } + return true; +} + +} // namespace mediakit diff --git a/ext-codec/VP9Rtp.h b/ext-codec/VP9Rtp.h new file mode 100644 index 00000000..098366f5 --- /dev/null +++ b/ext-codec/VP9Rtp.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_VP9RTPCODEC_H +#define ZLMEDIAKIT_VP9RTPCODEC_H + +#include "VP9.h" +// for DtsGenerator +#include "Common/Stamp.h" +#include "Rtsp/RtpCodec.h" + +namespace mediakit { + +/** + * VP9 rtp解码类 + * 将 VP9 over rtsp-rtp 解复用出 VP9Frame + */ +class VP9RtpDecoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + VP9RtpDecoder(); + + /** + * 输入VP9 rtp包 + * @param rtp rtp包 + * @param key_pos 此参数忽略之 + */ + bool inputRtp(const RtpPacket::Ptr &rtp, bool key_pos = true) override; + +private: + bool decodeRtp(const RtpPacket::Ptr &rtp); + void outputFrame(const RtpPacket::Ptr &rtp); + void obtainFrame(); +private: + bool _gop_dropped = false; + bool _frame_drop = true; + uint16_t _last_seq = 0; + VP9Frame::Ptr _frame; + DtsGenerator _dts_generator; +}; + +/** + * VP9 rtp打包类 + */ +class VP9RtpEncoder : public RtpCodec { +public: + using Ptr = std::shared_ptr; + + bool inputFrame(const Frame::Ptr &frame) override; +private: + uint16_t _pic_id = 0; +}; + +}//namespace mediakit + +#endif //ZLMEDIAKIT_VP9RTPCODEC_H diff --git a/ext-codec/VpxRtmp.cpp b/ext-codec/VpxRtmp.cpp new file mode 100644 index 00000000..34aedca6 --- /dev/null +++ b/ext-codec/VpxRtmp.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "VpxRtmp.h" +#include "Rtmp/utils.h" +#include "Common/config.h" +#include "Extension/Factory.h" +using namespace std; +using namespace toolkit; + +namespace mediakit { + +void VpxRtmpDecoder::inputRtmp(const RtmpPacket::Ptr &pkt) { + if (_info.codec == CodecInvalid) { + // First, determine if it is an enhanced rtmp + parseVideoRtmpPacket((uint8_t *)pkt->data(), pkt->size(), &_info); + } + + if (_info.is_enhanced) { + // Enhanced rtmp + parseVideoRtmpPacket((uint8_t *)pkt->data(), pkt->size(), &_info); + if (!_info.is_enhanced || _info.codec != getTrack()->getCodecId()) { + throw std::invalid_argument("Invalid enhanced-rtmp packet!"); + } + + auto data = (uint8_t *)pkt->data() + RtmpPacketInfo::kEnhancedRtmpHeaderSize; + auto size = pkt->size() - RtmpPacketInfo::kEnhancedRtmpHeaderSize; + switch (_info.video.pkt_type) { + case RtmpPacketType::PacketTypeSequenceStart: { + getTrack()->setExtraData(data, size); + break; + } + + case RtmpPacketType::PacketTypeCodedFramesX: + case RtmpPacketType::PacketTypeCodedFrames: { + auto pts = pkt->time_stamp; + if (RtmpPacketType::PacketTypeCodedFrames == _info.video.pkt_type) { + CHECK_RET(size > 3); + // SI24 = [CompositionTime Offset] + int32_t cts = (load_be24(data) + 0xff800000) ^ 0xff800000; + pts += cts; + data += 3; + size -= 3; + } + outputFrame((char*)data, size, pkt->time_stamp, pts); + break; + } + default: + WarnL << "Unknown pkt_type: " << (int)_info.video.pkt_type; + break; + } + } else { + CHECK_RET(pkt->size() > 5); + uint8_t *cts_ptr = (uint8_t *)(pkt->buffer.data() + 2); + int32_t cts = (load_be24(cts_ptr) + 0xff800000) ^ 0xff800000; + // 国内扩展(12) Vpx rtmp + if (pkt->isConfigFrame()) { + getTrack()->setExtraData((uint8_t *)pkt->data() + 5, pkt->size() - 5); + } else { + outputFrame(pkt->data() + 5, pkt->size() - 5, pkt->time_stamp, pkt->time_stamp + cts); + } + } +} + +void VpxRtmpDecoder::outputFrame(const char *data, size_t size, uint32_t dts, uint32_t pts) { + RtmpCodec::inputFrame(Factory::getFrameFromPtr(getTrack()->getCodecId(), data, size, dts, pts)); +} + +//////////////////////////////////////////////////////////////////////// +VpxRtmpEncoder::VpxRtmpEncoder(const Track::Ptr &track) : RtmpCodec(track) { + _enhanced = mINI::Instance()[Rtmp::kEnhanced]; +} + +bool VpxRtmpEncoder::inputFrame(const Frame::Ptr &frame) { + auto packet = RtmpPacket::create(); + packet->buffer.resize(8 + frame->size()); + char *buff = packet->data(); + int32_t cts = frame->pts() - frame->dts(); + if (_enhanced) { + auto header = (RtmpVideoHeaderEnhanced *)buff; + header->enhanced = 1; + header->frame_type = frame->keyFrame() ? (int)RtmpFrameType::key_frame : (int)RtmpFrameType::inter_frame; + header->fourcc = htonl(getCodecFourCC(frame->getCodecId())); + buff += RtmpPacketInfo::kEnhancedRtmpHeaderSize; + if (cts) { + header->pkt_type = (uint8_t)RtmpPacketType::PacketTypeCodedFrames; + set_be24(buff, cts); + buff += 3; + } else { + header->pkt_type = (uint8_t)RtmpPacketType::PacketTypeCodedFramesX; + } + } else { + // flags + uint8_t flags = getCodecFlags(frame->getCodecId()); + flags |= (uint8_t)(frame->keyFrame() ? RtmpFrameType::key_frame : RtmpFrameType::inter_frame) << 4; + + buff[0] = flags; + buff[1] = (uint8_t)RtmpH264PacketType::h264_nalu; + // cts + set_be24(&buff[2], cts); + buff += 5; + } + + packet->time_stamp = frame->dts(); + memcpy(buff, frame->data(), frame->size()); + buff += frame->size(); + packet->body_size = buff - packet->data(); + packet->chunk_id = CHUNK_VIDEO; + packet->stream_index = STREAM_MEDIA; + packet->type_id = MSG_VIDEO; + // Output rtmp packet + RtmpCodec::inputRtmp(packet); + return true; +} + +void VpxRtmpEncoder::makeConfigPacket() { + auto extra_data = getTrack()->getExtraData(); + if (!extra_data || !extra_data->size()) + return; + auto pkt = RtmpPacket::create(); + pkt->body_size = 5 + extra_data->size(); + pkt->buffer.resize(pkt->body_size); + auto buff = pkt->buffer.data(); + if (_enhanced) { + auto header = (RtmpVideoHeaderEnhanced *)buff; + header->enhanced = 1; + header->pkt_type = (int)RtmpPacketType::PacketTypeSequenceStart; + header->frame_type = (int)RtmpFrameType::key_frame; + header->fourcc = htonl(getCodecFourCC(getTrack()->getCodecId())); + } else { + uint8_t flags = getCodecFlags(getTrack()->getCodecId()); + flags |= ((uint8_t)RtmpFrameType::key_frame << 4); + buff[0] = flags; + buff[1] = (uint8_t)RtmpH264PacketType::h264_config_header; + // cts + memset(buff + 2, 0, 3); + } + memcpy(buff+5, extra_data->data(), extra_data->size()); + pkt->chunk_id = CHUNK_VIDEO; + pkt->stream_index = STREAM_MEDIA; + pkt->time_stamp = 0; + pkt->type_id = MSG_VIDEO; + RtmpCodec::inputRtmp(pkt); +} + +} // namespace mediakit diff --git a/ext-codec/VpxRtmp.h b/ext-codec/VpxRtmp.h new file mode 100644 index 00000000..e8752efc --- /dev/null +++ b/ext-codec/VpxRtmp.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_VPX_RTMPCODEC_H +#define ZLMEDIAKIT_VPX_RTMPCODEC_H + +#include "Rtmp/RtmpCodec.h" +#include "Extension/Track.h" + +namespace mediakit { +/** + * Rtmp解码类 + * 将 Vpx over rtmp 解复用出 VpxFrame + */ +class VpxRtmpDecoder : public RtmpCodec { +public: + using Ptr = std::shared_ptr; + + VpxRtmpDecoder(const Track::Ptr &track) : RtmpCodec(track) {} + + void inputRtmp(const RtmpPacket::Ptr &rtmp) override; + +protected: + void outputFrame(const char *data, size_t size, uint32_t dts, uint32_t pts); + +protected: + RtmpPacketInfo _info; +}; + +/** + * Rtmp打包类 + */ +class VpxRtmpEncoder : public RtmpCodec { + bool _enhanced = false; +public: + using Ptr = std::shared_ptr; + + VpxRtmpEncoder(const Track::Ptr &track); + + bool inputFrame(const Frame::Ptr &frame) override; + + void makeConfigPacket() override; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_VPX_RTMPCODEC_H diff --git a/player/AudioSRC.cpp b/player/AudioSRC.cpp index 5fb7aa0e..8989a7dc 100644 --- a/player/AudioSRC.cpp +++ b/player/AudioSRC.cpp @@ -26,7 +26,7 @@ void AudioSRC::setOutputAudioConfig(const SDL_AudioSpec &cfg) { int format = _delegate->getPCMFormat(); int channels = _delegate->getPCMChannel(); if (-1 == SDL_BuildAudioCVT(&_audio_cvt, format, channels, freq, cfg.format, cfg.channels, cfg.freq)) { - throw std::runtime_error("the format conversion is not supported"); + throw std::runtime_error("the format conversion is not supported, " + string(SDL_GetError())); } InfoL << "audio cvt origin format, freq:" << freq << ", format:" << hex << format << dec << ", channels:" << channels; InfoL << "audio cvt info, " diff --git a/player/CMakeLists.txt b/player/CMakeLists.txt index e36b2a46..e255f139 100644 --- a/player/CMakeLists.txt +++ b/player/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/player/SDLAudioDevice.cpp b/player/SDLAudioDevice.cpp index 14f74add..709c95e4 100644 --- a/player/SDLAudioDevice.cpp +++ b/player/SDLAudioDevice.cpp @@ -18,7 +18,10 @@ using namespace toolkit; INSTANCE_IMP(SDLAudioDevice); SDLAudioDevice::~SDLAudioDevice() { - SDL_CloseAudio(); + if (_device) { + SDL_CloseAudioDevice(_device); + _device = 0; + } } SDLAudioDevice::SDLAudioDevice() { @@ -33,9 +36,13 @@ SDLAudioDevice::SDLAudioDevice() { SDLAudioDevice *_this = (SDLAudioDevice *) userdata; _this->onReqPCM((char *) stream, len); }; - if (SDL_OpenAudioDevice(NULL, 0, &wanted_spec, &_audio_config, SDL_AUDIO_ALLOW_ANY_CHANGE) < 0) { - throw std::runtime_error("SDL_OpenAudioDevice failed"); - } + + _device = SDL_OpenAudioDevice(NULL, 0, &wanted_spec, &_audio_config, 0); + if (_device <= 0) + _device = SDL_OpenAudioDevice(NULL, 0, &wanted_spec, &_audio_config, SDL_AUDIO_ALLOW_ANY_CHANGE); + if (_device <= 0) { + throw std::runtime_error("SDL_OpenAudioDevice failed"); + } InfoL << "actual audioSpec, " << "freq:" << _audio_config.freq << ", format:" << hex << _audio_config.format << dec @@ -51,7 +58,7 @@ SDLAudioDevice::SDLAudioDevice() { void SDLAudioDevice::addChannel(AudioSRC *chn) { lock_guard lck(_channel_mtx); if (_channels.empty()) { - SDL_PauseAudio(0); + SDL_PauseAudioDevice(_device, false); } chn->setOutputAudioConfig(_audio_config); _channels.emplace(chn); @@ -61,7 +68,7 @@ void SDLAudioDevice::delChannel(AudioSRC *chn) { lock_guard lck(_channel_mtx); _channels.erase(chn); if (_channels.empty()) { - SDL_PauseAudio(true); + SDL_PauseAudioDevice(_device, true); } } diff --git a/player/SDLAudioDevice.h b/player/SDLAudioDevice.h index bddea9cb..69320995 100644 --- a/player/SDLAudioDevice.h +++ b/player/SDLAudioDevice.h @@ -40,6 +40,7 @@ private: void onReqPCM(char *stream, int len); private: + SDL_AudioDeviceID _device; std::shared_ptr _play_buf; SDL_AudioSpec _audio_config; std::recursive_mutex _channel_mtx; diff --git a/player/YuvDisplayer.h b/player/YuvDisplayer.h index 54702cec..318baa0f 100644 --- a/player/YuvDisplayer.h +++ b/player/YuvDisplayer.h @@ -135,16 +135,27 @@ public: } bool displayYUV(AVFrame *pFrame){ if (!_win) { + int w, h; + double hw = 0.0f; + w = pFrame->width; + h = pFrame->height; + hw = (double)h / (double)w; + w = 720; + h = w * hw; if (_hwnd) { _win = SDL_CreateWindowFrom(_hwnd); }else { _win = SDL_CreateWindow(_title.data(), SDL_WINDOWPOS_UNDEFINED, SDL_WINDOWPOS_UNDEFINED, - pFrame->width, - pFrame->height, - SDL_WINDOW_OPENGL); + w, + h, + SDL_WINDOW_OPENGL |SDL_WINDOW_RESIZABLE | SDL_WINDOW_SHOWN); // 允许最大化 } + SDL_SetWindowInputFocus(_win); + SDL_RaiseWindow(_win); + // SDL_GL_SetSwapInterval(1); // 1 ,“开启垂直同步”就是让程序“等显示器”,以牺牲一点延迟换取画面完整无撕裂。 + } if (_win && ! _render){ #if 0 diff --git a/player/test_player.cpp b/player/test_player.cpp index d0cb8afb..9904dfd5 100644 --- a/player/test_player.cpp +++ b/player/test_player.cpp @@ -41,6 +41,10 @@ int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstanc, LPSTR lpCmdLine, freopen_s(&stream, "CON", "r", stdin);//重定向输入流 freopen_s(&stream, "CON", "w", stdout);//重定向输入流 + // 清除流缓冲区, 在win11上还是无法输出文字,需要在加入如下代码 + std::cin.clear(); + std::cout.clear(); + //3. 如果我们需要用到控制台窗口句柄,可以调用FindWindow取得: HWND _consoleHwnd; SetConsoleTitleA("test_player");//设置窗口名 @@ -56,8 +60,8 @@ int main(int argc, char *argv[]) { Logger::Instance().add(std::make_shared()); Logger::Instance().setWriter(std::make_shared()); - if (argc < 3) { - ErrorL << "\r\n测试方法:./test_player rtxp_url rtp_type\r\n" + if (argc < 2) { + ErrorL << "\r\n测试方法:./test_player rtxp_url [rtp_type] [play_track]\r\n" << "例如:./test_player rtsp://admin:123456@127.0.0.1/live/0 0\r\n"; return 0; } @@ -97,10 +101,14 @@ int main(int argc, char *argv[]) { decoder->setOnDecode([audio_player, swr](const FFmpegFrame::Ptr &frame) mutable { if (!swr) { +# if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + swr = std::make_shared(AV_SAMPLE_FMT_S16, &(frame->get()->ch_layout), frame->get()->sample_rate); +#else swr = std::make_shared(AV_SAMPLE_FMT_S16, frame->get()->channels, frame->get()->channel_layout, frame->get()->sample_rate); +#endif } auto pcm = swr->inputFrame(frame); - auto len = pcm->get()->nb_samples * pcm->get()->channels * av_get_bytes_per_sample((enum AVSampleFormat)pcm->get()->format); + auto len = pcm->get()->nb_samples * pcm->getChannels() * av_get_bytes_per_sample((enum AVSampleFormat)pcm->get()->format); audio_player->playPCM((const char *)(pcm->get()->data[0]), MIN(len, frame->get()->linesize[0])); }); audioTrack->addDelegate([decoder](const Frame::Ptr &frame) { return decoder->inputFrame(frame, false, true); }); @@ -108,10 +116,11 @@ int main(int argc, char *argv[]) { }); player->setOnShutdown([](const SockException &ex) { WarnL << "play shutdown: " << ex.what(); }); - - (*player)[Client::kRtpType] = atoi(argv[2]); // 不等待track ready再回调播放成功事件,这样可以加快秒开速度 (*player)[Client::kWaitTrackReady] = false; + if (argc > 2) { + (*player)[Client::kRtpType] = atoi(argv[2]); + } if (argc > 3) { (*player)[Client::kPlayTrack] = atoi(argv[3]); } diff --git a/postman/ZLMediaKit.postman_collection.json b/postman/ZLMediaKit.postman_collection.json index a3da5a65..c7d1b26b 100644 --- a/postman/ZLMediaKit.postman_collection.json +++ b/postman/ZLMediaKit.postman_collection.json @@ -39,14 +39,15 @@ "method": "GET", "header": [], "url": { - "raw": "{{ZLMediaKit_URL}}/index/api/getApiList?secret={{ZLMediaKit_secret}}&id=stack_test", + "raw": "{{ZLMediaKit_URL}}/index/api/stack/stop?secret={{ZLMediaKit_secret}}&id=stack_test", "host": [ "{{ZLMediaKit_URL}}" ], "path": [ "index", "api", - "getApiList" + "stack", + "stop" ], "query": [ { @@ -56,7 +57,44 @@ }, { "key": "id", - "value": "stack_test" + "value": "stack_test", + "description": "多屏拼接id" + } + ] + } + }, + "response": [] + }, + { + "name": "重置多屏拼接(stack/reset)", + "request": { + "method": "POST", + "header": [], + "body": { + "mode": "raw", + "raw": "{\r\n \"gapv\": 0.002,\r\n \"gaph\": 0.001,\r\n \"width\": 1920,\r\n \"url\": [\r\n [\r\n \"rtsp://kkem.me/live/test3\",\r\n \"rtsp://kkem.me/live/cy1\",\r\n \"rtsp://kkem.me/live/cy1\",\r\n \"rtsp://kkem.me/live/cy2\"\r\n ],\r\n [\r\n \"rtsp://kkem.me/live/cy1\",\r\n \"rtsp://kkem.me/live/cy5\",\r\n \"rtsp://kkem.me/live/cy3\",\r\n \"rtsp://kkem.me/live/cy4\"\r\n ],\r\n [\r\n \"rtsp://kkem.me/live/cy5\",\r\n \"rtsp://kkem.me/live/cy6\",\r\n \"rtsp://kkem.me/live/cy7\",\r\n \"rtsp://kkem.me/live/cy8\"\r\n ],\r\n [\r\n \"rtsp://kkem.me/live/cy9\",\r\n \"rtsp://kkem.me/live/cy10\",\r\n \"rtsp://kkem.me/live/cy11\",\r\n \"rtsp://kkem.me/live/cy12\"\r\n ]\r\n ],\r\n \"id\": \"89\",\r\n \"row\": 4,\r\n \"col\": 4,\r\n \"height\": 1080,\r\n \"span\": [\r\n [\r\n [\r\n 0,\r\n 0\r\n ],\r\n [\r\n 1,\r\n 1\r\n ]\r\n ],\r\n [\r\n [\r\n 3,\r\n 0\r\n ],\r\n [\r\n 3,\r\n 1\r\n ]\r\n ],\r\n [\r\n [\r\n 2,\r\n 3\r\n ],\r\n [\r\n 3,\r\n 3\r\n ]\r\n ]\r\n ]\r\n}", + "options": { + "raw": { + "language": "json" + } + } + }, + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/stack/reset?secret={{ZLMediaKit_secret}}", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "stack", + "reset" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}", + "description": "api操作密钥(配置文件配置)" } ] } @@ -310,6 +348,53 @@ }, "response": [] }, + { + "name": "删除截图(deleteSnapDirectory)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/deleteSnapDirectory?secret={{ZLMediaKit_secret}}&vhost={{defaultVhost}}&app=live&stream=test&file=71_1740828613.jpg", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "deleteSnapDirectory" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}", + "description": "api操作密钥(配置文件配置)" + }, + { + "key": "vhost", + "value": "{{defaultVhost}}", + "description": "筛选虚拟主机,例如__defaultVhost__" + }, + { + "key": "app", + "value": "live", + "description": "筛选应用名,例如 live" + }, + { + "key": "stream", + "value": "test", + "description": "筛选流id,例如 test" + }, + { + "key": "file", + "value": "", + "disabled": true, + "description": "文件名,非必选" + } + ] + } + }, + "response": [] + }, { "name": "关断单个流(close_stream)", "request": { @@ -522,7 +607,7 @@ "response": [] }, { - "name": "添加rtsp/rtmp/hls/srt拉流代理(addStreamProxy)", + "name": "添加拉流代理(addStreamProxy)", "request": { "method": "GET", "header": [], @@ -560,7 +645,7 @@ { "key": "url", "value": "rtmp://live.hkstv.hk.lxdns.com/live/hks2", - "description": "拉流地址,例如rtmp://live.hkstv.hk.lxdns.com/live/hks2" + "description": "拉流地址,支持rtsp/rtmp/hls/srt/http-flv/http-ts协议" }, { "key": "rtp_type", @@ -828,6 +913,12 @@ "description": "推流重试次数,不传此参数或传值<=0时,则无限重试", "disabled": true }, + { + "key": "force", + "value": null, + "description": "是否强制添加代理,默认0,设置为1时如果拉流失败也会不断重试", + "disabled": true + }, { "key": "latency", "value": null, @@ -1179,19 +1270,19 @@ "response": [] }, { - "name": "获取流信息(getMp4RecordFile)", + "name": "获取录像文件列表(getMP4RecordFile)", "request": { "method": "GET", "header": [], "url": { - "raw": "{{ZLMediaKit_URL}}/index/api/getMp4RecordFile?secret={{ZLMediaKit_secret}}&vhost={{defaultVhost}}&app=proxy&stream=2&customized_path=/www&period=2020-05-26", + "raw": "{{ZLMediaKit_URL}}/index/api/getMP4RecordFile?secret={{ZLMediaKit_secret}}&vhost={{defaultVhost}}&app=proxy&stream=2&customized_path=/www&period=2020-05-26", "host": [ "{{ZLMediaKit_URL}}" ], "path": [ "index", "api", - "getMp4RecordFile" + "getMP4RecordFile" ], "query": [ { @@ -1333,6 +1424,62 @@ }, "response": [] }, + { + "name": "开始事件视频录制(startRecordTask)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/startRecordTask?secret={{ZLMediaKit_secret}}&vhost={{defaultVhost}}&app=live&stream=test&path=1.mp4&back_ms=10000&forward_ms=10000", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "startRecordTask" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}", + "description": "api操作密钥(配置文件配置)" + }, + { + "key": "vhost", + "value": "{{defaultVhost}}", + "description": "虚拟主机,例如__defaultVhost__" + }, + { + "key": "app", + "value": "live", + "description": "应用名,例如 live" + }, + { + "key": "stream", + "value": "test", + "description": "流id,例如 obs" + }, + { + "key": "path", + "value": "1.mp4", + "description": "录像文件保存相对路径,包括名称" + }, + { + "key": "back_ms", + "value": "10000", + "description": "回溯录制时长" + }, + { + "key": "forward_ms", + "value": "10000", + "description": "后续录制时长" + } + ] + } + }, + "response": [] + }, { "name": "设置录像速度(setRecordSpeed)", "request": { @@ -1552,6 +1699,12 @@ "key": "expire_sec", "value": "1", "description": "截图的过期时间,该时间内产生的截图都会作为缓存返回" + }, + { + "key": "async", + "value": "0", + "disabled": true, + "description": "是否采用zlm内置播放器、解码器api异步截图,开启后截图速度提升但兼容性降低" } ] } @@ -1910,6 +2063,12 @@ "key": "stream_id", "value": "test", "description": "该端口绑定的流id" + }, + { + "key": "pause_seconds", + "value": "300", + "description": "暂停超时监测后,将在pause_seconds时间后恢复", + "disabled": true } ] } @@ -2086,6 +2245,12 @@ "value": "", "description": "发送rtp同时接收,一般用于双向语言对讲, 如果不为空,说明开启接收,值为接收流的id", "disabled": true + }, + { + "key": "enable_origin_recv_limit", + "value": "1", + "description": "转发rtp(tcp模式)时,如果发送不出去,是否限制源端收流速度,此参数在多倍速rtp转发时作用较大", + "disabled": true } ] } @@ -2180,6 +2345,12 @@ "value": "5000", "description": "等待tcp连接超时时间,单位毫秒,默认5000毫秒", "disabled": true + }, + { + "key": "enable_origin_recv_limit", + "value": "1", + "description": "转发rtp(tcp模式)时,如果发送不出去,是否限制源端收流速度,此参数在多倍速rtp转发时作用较大", + "disabled": true } ] } @@ -2255,6 +2426,12 @@ "value": "1", "description": "rtp es方式打包时,是否只打包音频;该参数非必选参数", "disabled": true + }, + { + "key": "enable_origin_recv_limit", + "value": "1", + "description": "转发rtp(tcp模式)时,如果发送不出去,是否限制源端收流速度,此参数在多倍速rtp转发时作用较大", + "disabled": true } ] } @@ -2479,6 +2656,18 @@ "description": "是否循环点播mp4文件,如果配置文件已经开启循环点播,此参数无效", "disabled": true }, + { + "key": "seek_ms", + "value": "0", + "description": "点播seek到特定位置,单位毫秒", + "disabled": true + }, + { + "key": "speed", + "value": "1.0", + "description": "播放速度, float类型", + "disabled": true + }, { "key": "enable_hls", "value": "", @@ -2599,7 +2788,498 @@ } }, "response": [] - } + }, + { + "name": "WebRTC-注册到信令服务器(addWebrtcRoomKeeper)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/addWebrtcRoomKeeper?secret={{ZLMediaKit_secret}}&server_host=127.0.0.1&server_port=3000&room_id=peer_1", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "addWebrtcRoomKeeper" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + }, + { + "key": "server_host", + "value": "127.0.0.1", + "description": "要注册到的信令服务器地址" + }, + { + "key": "server_port", + "value": "3000", + "description": "要注册到的信令服务器端口" + }, + { + "key": "room_id", + "value": "peer_1", + "description": "要注册到的roomid" + } + ] + } + }, + "response": [] + }, + { + "name": "WebRTC-从信令服务器注销(delWebrtcRoomKeeper)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/delWebrtcRoomKeeper?secret={{ZLMediaKit_secret}}&room_key=", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "delWebrtcRoomKeeper" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + }, + { + "key": "room_key", + "value": "" + } + ] + } + }, + "response": [] + }, + { + "name": "WebRTC-Peer查看注册信息(listWebrtcRoomKeepers)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/listWebrtcRoomKeepers?secret={{ZLMediaKit_secret}}", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "listWebrtcRoomKeepers" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + } + ] + } + }, + "response": [] + }, + { + "name": "WebRTC-信令服务器查看注册信息(listWebrtcRooms)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/listWebrtcRooms?secret={{ZLMediaKit_secret}}", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "listWebrtcRooms" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + } + ] + } + }, + "response": [] + }, + { + "name": "WebRTC-查看WebRTCProxyPlayer连接信息(getWebrtcProxyPlayerInfo)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/getWebrtcProxyPlayerInfo?secret={{ZLMediaKit_secret}}&key=__defaultVhost__/live/test", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "getWebrtcProxyPlayerInfo" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + }, + { + "key": "key", + "value": "__defaultVhost__/live/test" + } + ] + } + }, + "response": [] + }, + { + "name": "onvif 搜索", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/searchOnvifDevice?secret={{ZLMediaKit_secret}}&timeout_ms=5000", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "searchOnvifDevice" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + }, + { + "key": "subnet_prefix", + "value": "192.168.1" + } + ] + } + }, + "response": [] + }, + { + "name": "获取 onvif 设备url", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/getStreamUrl?secret={{ZLMediaKit_secret}}&onvif_url=http://xxxx/onvif/device_service", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "getStreamUrl" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}" + }, + { + "key": "onvif_url", + "value": "http://xxxx/onvif/device_service" + } + ] + } + }, + "response": [] + }, + { + "name": "下载程序二进制文件(downloadBin)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/downloadBin?secret={{ZLMediaKit_secret}}", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "downloadBin" + ], + "query": [ + { + "key": "secret", + "value": "{{ZLMediaKit_secret}}", + "description": "api操作密钥(配置文件配置)" + } + ] + } + }, + "response": [] + }, + { + "name": "WebRTC交互(webrtc)", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/json" + } + ], + "body": { + "mode": "raw", + "raw": "" + }, + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/webrtc?secret={{ZLMediaKit_secret}}&type=play&app=live&stream=test", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "webrtc" + ], + "query": [ + { + "key": "type", + "value": "play", + "description": "webrtc类型,play为播放,push为推流,echo为回显测试" + }, + { + "key": "app", + "value": "live", + "description": "应用名" + }, + { + "key": "stream", + "value": "test", + "description": "流id" + }, + { + "key": "preferred_tcp", + "value": null, + "description": "是否webrtc over tcp优先模式", + "disabled": true + }, + { + "key": "cand_udp", + "value": "test", + "description": "指定zlm服务器udp candidate", + "disabled": true + }, + { + "key": "cand_tcp", + "value": null, + "description": "指定zlm服务器tcp candidate", + "disabled": true + } + ] + }, + "description": "WebRTC交互接口,body为SDP offer" + }, + "response": [] + }, + { + "name": "WebRTC-WHIP推流(whip)", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/sdp" + } + ], + "body": { + "mode": "raw", + "raw": "" + }, + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/whip?app=live&stream=test", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "whip" + ], + "query": [ + { + "key": "app", + "value": "live", + "description": "应用名" + }, + { + "key": "stream", + "value": "test", + "description": "流id" + }, + { + "key": "preferred_tcp", + "value": null, + "description": "是否webrtc over tcp优先模式", + "disabled": true + }, + { + "key": "cand_udp", + "value": "test", + "description": "指定zlm服务器udp candidate", + "disabled": true + }, + { + "key": "cand_tcp", + "value": null, + "description": "指定zlm服务器tcp candidate", + "disabled": true + } + ] + }, + "description": "WebRTC WHIP标准推流接口,body为SDP offer" + }, + "response": [] + }, + { + "name": "WebRTC-WHEP播放(whep)", + "request": { + "method": "POST", + "header": [ + { + "key": "Content-Type", + "value": "application/sdp" + } + ], + "body": { + "mode": "raw", + "raw": "" + }, + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/whep?app=live&stream=test", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "whep" + ], + "query": [ + { + "key": "app", + "value": "live", + "description": "应用名" + }, + { + "key": "stream", + "value": "test", + "description": "流id" + }, + { + "key": "preferred_tcp", + "value": null, + "description": "是否webrtc over tcp优先模式", + "disabled": true + }, + { + "key": "cand_udp", + "value": "test", + "description": "指定zlm服务器udp candidate", + "disabled": true + }, + { + "key": "cand_tcp", + "value": null, + "description": "指定zlm服务器tcp candidate", + "disabled": true + } + ] + }, + "description": "WebRTC WHEP标准播放接口,body为SDP offer" + }, + "response": [] + }, + { + "name": "WebRTC-删除连接(delete_webrtc)", + "request": { + "method": "DELETE", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/delete_webrtc?id=&token=", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "delete_webrtc" + ], + "query": [ + { + "key": "id", + "value": "", + "description": "WebRTC连接的唯一标识" + }, + { + "key": "token", + "value": "", + "description": "删除操作的验证token" + } + ] + }, + "description": "删除WebRTC连接,需要使用DELETE方法。id和token由whip/whep接口返回的Location头中获取。" + }, + "response": [] + }, + { + "name": "登录(login)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/login?digest=d00414822dfd8eabed87c5e24ffcdca7", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "login" + ], + "query": [ + { + "key": "digest", + "value": "", + "description": "MD5(\"zlmediakit:\"+${secret}+\":\" +${cookie})" + } + ] + } + }, + "response": [] + }, + { + "name": "登出(logout)", + "request": { + "method": "GET", + "header": [], + "url": { + "raw": "{{ZLMediaKit_URL}}/index/api/logout", + "host": [ + "{{ZLMediaKit_URL}}" + ], + "path": [ + "index", + "api", + "logout" + ] + } + }, + "response": [] + } ], "event": [ { diff --git a/resource.rc b/resource.rc new file mode 100644 index 00000000..8ad6bd07 --- /dev/null +++ b/resource.rc @@ -0,0 +1,48 @@ +#ifdef APSTUDIO_INVOKED +#error "This file is not editable by Visual C++." +#endif //APSTUDIO_INVOKED + +#include "winres.h" + +#if defined(ENABLE_VERSION) +#include "ZLMVersion.h" +#endif + +#define ZLM_VERSION 8,0,0,1 + +// 拼接 BRANCH_NAME 和 COMMIT_HASH ,例如 master - 1c8ed1c +#define COMMIT_HASH_BRANCH_STR BRANCH_NAME " - " COMMIT_HASH + +IDI_ICON1 ICON DISCARDABLE "www//logo.ico" + +VS_VERSION_INFO VERSIONINFO + FILEVERSION ZLM_VERSION + PRODUCTVERSION ZLM_VERSION + FILEFLAGSMASK 0x17L +#ifdef _DEBUG + FILEFLAGS 0x1L +#else + FILEFLAGS 0x0L +#endif + FILEOS 0x4L + FILETYPE 0x2L + FILESUBTYPE 0x0L +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "000004b0" + BEGIN + VALUE "CompanyName", "Applied ZLMediaKit Informatics Software" + VALUE "FileDescription", "This file is part of the C++ ZLM" + VALUE "FileVersion", COMMIT_HASH_BRANCH_STR + VALUE "InternalName", COMMIT_HASH_BRANCH_STR + VALUE "LegalCopyright", "Copyright (c) 2016-present The ZLMediaKit project authors" + VALUE "ProductName", "https://github.com/ZLMediaKit" + VALUE "ProductVersion", COMMIT_HASH_BRANCH_STR + END + END + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x0, 1200 + END +END diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 63abc05d..9f063e40 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -50,10 +50,43 @@ target_compile_definitions(MediaServer target_compile_options(MediaServer PRIVATE ${COMPILE_OPTIONS_DEFAULT}) +if(MINGW) + update_cached_list(MK_LINK_LIBRARIES dbghelp) +endif() + if(CMAKE_SYSTEM_NAME MATCHES "Linux") target_link_libraries(MediaServer -Wl,--start-group ${MK_LINK_LIBRARIES} -Wl,--end-group) else() target_link_libraries(MediaServer ${MK_LINK_LIBRARIES}) endif() +if(MSVC) + set(RESOURCE_FILE "${CMAKE_SOURCE_DIR}/resource.rc") + set_source_files_properties(${RESOURCE_FILE} PROPERTIES LANGUAGE RC) + target_sources(MediaServer PRIVATE ${RESOURCE_FILE}) +else() + # Android, IOS, macOS ... + # CLion, GCC ... +endif() + install(TARGETS MediaServer DESTINATION ${INSTALL_PATH_RUNTIME}) + +#relase 类型时额外输出debug调试信息 +string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER) +if(UNIX AND ENABLE_OBJCOPY) + if("${CMAKE_BUILD_TYPE_LOWER}" STREQUAL "release") + find_program(OBJCOPY_FOUND objcopy) + if (OBJCOPY_FOUND) + add_custom_command(TARGET MediaServer + POST_BUILD + COMMAND objcopy --only-keep-debug ${EXECUTABLE_OUTPUT_PATH}/MediaServer ${EXECUTABLE_OUTPUT_PATH}/MediaServer.debug + COMMAND objcopy --strip-all ${EXECUTABLE_OUTPUT_PATH}/MediaServer + COMMAND objcopy --add-gnu-debuglink=${EXECUTABLE_OUTPUT_PATH}/MediaServer.debug ${EXECUTABLE_OUTPUT_PATH}/MediaServer + ) + install(FILES ${EXECUTABLE_OUTPUT_PATH}/MediaServer.debug DESTINATION ${INSTALL_PATH_RUNTIME}) + else() + message(STATUS "not found objcopy, generate MediaServer.debug skip") + endif() + endif() +endif() + diff --git a/server/FFmpegSource.cpp b/server/FFmpegSource.cpp index 09b16e50..4300ae3e 100644 --- a/server/FFmpegSource.cpp +++ b/server/FFmpegSource.cpp @@ -84,86 +84,92 @@ void FFmpegSource::play(const string &ffmpeg_cmd_key, const string &src_url, con try { _media_info.parse(dst_url); - } catch (std::exception &ex) { - cb(SockException(Err_other, ex.what())); - return; - } - auto ffmpeg_cmd = ffmpeg_cmd_default; - if (!ffmpeg_cmd_key.empty()) { - auto cmd_it = mINI::Instance().find(ffmpeg_cmd_key); - if (cmd_it != mINI::Instance().end()) { - ffmpeg_cmd = cmd_it->second; + auto ffmpeg_cmd = ffmpeg_cmd_default; + if (!ffmpeg_cmd_key.empty()) { + auto cmd_it = mINI::Instance().find(ffmpeg_cmd_key); + if (cmd_it != mINI::Instance().end()) { + ffmpeg_cmd = cmd_it->second; + } else { + WarnL << "配置文件中,ffmpeg命令模板(" << ffmpeg_cmd_key << ")不存在,已采用默认模板(" << ffmpeg_cmd_default << ")"; + } + } + if (!toolkit::start_with(ffmpeg_cmd, "%s")) { + throw std::invalid_argument("ffmpeg cmd template must start with '%s'"); + } + + char cmd[2048] = { 0 }; + snprintf(cmd, sizeof(cmd), ffmpeg_cmd.data(), File::absolutePath("", ffmpeg_bin).data(), src_url.data(), dst_url.data()); + auto log_file = ffmpeg_log.empty() ? "" : File::absolutePath("", ffmpeg_log); + _process.run(cmd, log_file); + _cmd = cmd; + InfoL << cmd; + + if (is_local_ip(_media_info.host)) { + // 推流给自己的,通过判断流是否注册上来判断是否正常 [AUTO-TRANSLATED:423f2be6] + // Push stream to yourself, judge whether the stream is registered to determine whether it is normal + if (_media_info.schema != RTSP_SCHEMA && _media_info.schema != RTMP_SCHEMA && _media_info.schema != "srt") { + cb(SockException(Err_other, "本服务只支持rtmp/rtsp/srt推流")); + return; + } + weak_ptr weakSelf = shared_from_this(); + findAsync(timeout_ms, [cb, weakSelf, timeout_ms](const MediaSource::Ptr &src) { + auto strongSelf = weakSelf.lock(); + if (!strongSelf) { + // 自己已经销毁 [AUTO-TRANSLATED:3d45c3b0] + // Self has been destroyed + return; + } + if (src) { + // 推流给自己成功 [AUTO-TRANSLATED:65dba71b] + // Push stream to yourself successfully + cb(SockException()); + strongSelf->onGetMediaSource(src); + strongSelf->startTimer(timeout_ms); + return; + } + // 推流失败 [AUTO-TRANSLATED:4d8d226a] + // Push stream failed + if (!strongSelf->_process.wait(false)) { + // ffmpeg进程已经退出 [AUTO-TRANSLATED:04193893] + // ffmpeg process has exited + cb(SockException(Err_other, StrPrinter << "ffmpeg已经退出,exit code = " << strongSelf->_process.exit_code())); + return; + } + // ffmpeg进程还在线,但是等待推流超时 [AUTO-TRANSLATED:9f71f17b] + // ffmpeg process is still online, but waiting for the stream to timeout + cb(SockException(Err_other, "等待超时")); + }); } else { - WarnL << "配置文件中,ffmpeg命令模板(" << ffmpeg_cmd_key << ")不存在,已采用默认模板(" << ffmpeg_cmd_default << ")"; + // 推流给其他服务器的,通过判断FFmpeg进程是否在线判断是否成功 [AUTO-TRANSLATED:9b963da5] + // Push stream to other servers, judge whether it is successful by judging whether the FFmpeg process is online + weak_ptr weakSelf = shared_from_this(); + _timer = std::make_shared( + timeout_ms / 1000.0f, + [weakSelf, cb, timeout_ms]() { + auto strongSelf = weakSelf.lock(); + if (!strongSelf) { + // 自身已经销毁 [AUTO-TRANSLATED:5f954f8a] + // Self has been destroyed + return false; + } + // FFmpeg还在线,那么我们认为推流成功 [AUTO-TRANSLATED:4330df49] + // FFmpeg is still online, so we think the push stream is successful + if (strongSelf->_process.wait(false)) { + cb(SockException()); + strongSelf->startTimer(timeout_ms); + return false; + } + // ffmpeg进程已经退出 [AUTO-TRANSLATED:04193893] + // ffmpeg process has exited + cb(SockException(Err_other, StrPrinter << "ffmpeg已经退出,exit code = " << strongSelf->_process.exit_code())); + return false; + }, + _poller); } - } - - char cmd[2048] = { 0 }; - snprintf(cmd, sizeof(cmd), ffmpeg_cmd.data(), File::absolutePath("", ffmpeg_bin).data(), src_url.data(), dst_url.data()); - auto log_file = ffmpeg_log.empty() ? "" : File::absolutePath("", ffmpeg_log); - _process.run(cmd, log_file); - _cmd = cmd; - InfoL << cmd; - - if (is_local_ip(_media_info.host)) { - // 推流给自己的,通过判断流是否注册上来判断是否正常 [AUTO-TRANSLATED:423f2be6] - // Push stream to yourself, judge whether the stream is registered to determine whether it is normal - if (_media_info.schema != RTSP_SCHEMA && _media_info.schema != RTMP_SCHEMA) { - cb(SockException(Err_other, "本服务只支持rtmp/rtsp推流")); - return; - } - weak_ptr weakSelf = shared_from_this(); - findAsync(timeout_ms, [cb, weakSelf, timeout_ms](const MediaSource::Ptr &src) { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) { - // 自己已经销毁 [AUTO-TRANSLATED:3d45c3b0] - // Self has been destroyed - return; - } - if (src) { - // 推流给自己成功 [AUTO-TRANSLATED:65dba71b] - // Push stream to yourself successfully - cb(SockException()); - strongSelf->onGetMediaSource(src); - strongSelf->startTimer(timeout_ms); - return; - } - // 推流失败 [AUTO-TRANSLATED:4d8d226a] - // Push stream failed - if (!strongSelf->_process.wait(false)) { - // ffmpeg进程已经退出 [AUTO-TRANSLATED:04193893] - // ffmpeg process has exited - cb(SockException(Err_other, StrPrinter << "ffmpeg已经退出,exit code = " << strongSelf->_process.exit_code())); - return; - } - // ffmpeg进程还在线,但是等待推流超时 [AUTO-TRANSLATED:9f71f17b] - // ffmpeg process is still online, but waiting for the stream to timeout - cb(SockException(Err_other, "等待超时")); - }); - } else{ - // 推流给其他服务器的,通过判断FFmpeg进程是否在线判断是否成功 [AUTO-TRANSLATED:9b963da5] - // Push stream to other servers, judge whether it is successful by judging whether the FFmpeg process is online - weak_ptr weakSelf = shared_from_this(); - _timer = std::make_shared(timeout_ms / 1000.0f, [weakSelf, cb, timeout_ms]() { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) { - // 自身已经销毁 [AUTO-TRANSLATED:5f954f8a] - // Self has been destroyed - return false; - } - // FFmpeg还在线,那么我们认为推流成功 [AUTO-TRANSLATED:4330df49] - // FFmpeg is still online, so we think the push stream is successful - if (strongSelf->_process.wait(false)) { - cb(SockException()); - strongSelf->startTimer(timeout_ms); - return false; - } - // ffmpeg进程已经退出 [AUTO-TRANSLATED:04193893] - // ffmpeg process has exited - cb(SockException(Err_other, StrPrinter << "ffmpeg已经退出,exit code = " << strongSelf->_process.exit_code())); - return false; - }, _poller); + } catch (std::exception &ex) { + WarnL << ex.what(); + cb(SockException(Err_other, ex.what())); } } @@ -341,15 +347,70 @@ void FFmpegSource::onGetMediaSource(const MediaSource::Ptr &src) { setDelegate(listener); muxer->setDelegate(shared_from_this()); if (_enable_hls) { - src->setupRecord(Recorder::type_hls, true, "", 0); + src->getOwnerPoller()->async([=]() mutable { + src->setupRecord(Recorder::type_hls, true, "", 0); + }); } if (_enable_mp4) { - src->setupRecord(Recorder::type_mp4, true, "", 0); + src->getOwnerPoller()->async([=]() mutable { + src->setupRecord(Recorder::type_mp4, true, "", 0); + }); } } } -void FFmpegSnap::makeSnap(const string &play_url, const string &save_path, float timeout_sec, const onSnap &cb) { +#if defined(ENABLE_FFMPEG) +#include "Player/MediaPlayer.h" +#include "Codec/Transcode.h" + +static void makeSnapAsync(const string &play_url, const string &save_path, float timeout_sec, const FFmpegSnap::onSnap &cb) { + struct Holder { + MediaPlayer::Ptr player; + }; + auto holder = std::make_shared(); + auto player = std::make_shared(); + (*player)[mediakit::Client::kTimeoutMS] = timeout_sec * 1000; + + player->setOnPlayResult([holder, save_path, cb, timeout_sec](const SockException &ex) mutable { + onceToken token(nullptr, [&]() { holder->player = nullptr; }); + auto video = ex ? nullptr : dynamic_pointer_cast(holder->player->getTrack(TrackVideo, false)); + if (!video) { + cb(false, ex ? ex.what() : "none video track"); + return; + } + auto decoder = std::make_shared(video); + auto new_holder = std::make_shared(*holder); + auto timer = EventPollerPool::Instance().getPoller()->doDelayTask(1000 * timeout_sec, [cb, new_holder]() { + // 防止解码失败导致播放器无法释放 + new_holder->player = nullptr; + cb(false, "decode frame timeout"); + return 0; + }); + auto done = false; + decoder->setOnDecode([save_path, new_holder, cb, done, timer](const FFmpegFrame::Ptr &frame) mutable { + if (done) { + return; + } + onceToken token(nullptr, [&]() { new_holder->player = nullptr; timer->cancel(); done = true; }); + auto ret = FFmpegUtils::saveFrame(frame, save_path.data()); + cb(std::get<0>(ret), std::get<1>(ret)); + }); + video->addDelegate([decoder](const Frame::Ptr &frame) { return decoder->inputFrame(frame, false, true); }); + }); + player->play(play_url); + holder->player = std::move(player); +} + +#endif + +void FFmpegSnap::makeSnap(bool async, const string &play_url, const string &save_path, float timeout_sec, const onSnap &cb) { +#if defined(ENABLE_FFMPEG) + if (async) { + makeSnapAsync(play_url, save_path, timeout_sec, cb); + return; + } +#endif + GET_CONFIG(string, ffmpeg_bin, FFmpeg::kBin); GET_CONFIG(string, ffmpeg_snap, FFmpeg::kSnap); GET_CONFIG(string, ffmpeg_log, FFmpeg::kLog); diff --git a/server/FFmpegSource.h b/server/FFmpegSource.h index 4bc44e33..c11a7426 100644 --- a/server/FFmpegSource.h +++ b/server/FFmpegSource.h @@ -26,17 +26,20 @@ namespace FFmpeg { class FFmpegSnap { public: using onSnap = std::function; - // / 创建截图 [AUTO-TRANSLATED:6d334c49] - // / Create a screenshot - // / \param play_url 播放url地址,只要FFmpeg支持即可 [AUTO-TRANSLATED:609d4de4] - // / \param play_url The playback URL address, as long as FFmpeg supports it - // / \param save_path 截图jpeg文件保存路径 [AUTO-TRANSLATED:0fc0ac0d] - // / \param save_path The path to save the screenshot JPEG file - // / \param timeout_sec 生成截图超时时间(防止阻塞太久) [AUTO-TRANSLATED:0dcc0095] - // / \param timeout_sec Timeout for generating the screenshot (to prevent blocking for too long) - // / \param cb 生成截图成功与否回调 [AUTO-TRANSLATED:5b4b93c9] - // / \param cb Callback for whether the screenshot was generated successfully - static void makeSnap(const std::string &play_url, const std::string &save_path, float timeout_sec, const onSnap &cb); + /** + * 创建截图 [AUTO-TRANSLATED:6d334c49] + * Create a screenshot + * @param async 是否使用异步截图方式(非ffmpeg命令行,而是使用zlm api,但是仅限于zlm播放器支持的拉流协议) + * @param play_url 播放url地址,只要FFmpeg支持即可 [AUTO-TRANSLATED:609d4de4] + * @param play_url The playback URL address, as long as FFmpeg supports it + * @param save_path 截图jpeg文件保存路径 [AUTO-TRANSLATED:0fc0ac0d] + * @param save_path The path to save the screenshot JPEG file + * @param timeout_sec 生成截图超时时间(防止阻塞太久) [AUTO-TRANSLATED:0dcc0095] + * @param timeout_sec Timeout for generating the screenshot (to prevent blocking for too long) + * @param cb 生成截图成功与否回调 [AUTO-TRANSLATED:5b4b93c9] + * @param cb Callback for whether the screenshot was generated successfully + */ + static void makeSnap(bool async, const std::string &play_url, const std::string &save_path, float timeout_sec, const onSnap &cb); private: FFmpegSnap() = delete; diff --git a/server/Process.cpp b/server/Process.cpp index 120ba0ab..aba6b612 100644 --- a/server/Process.cpp +++ b/server/Process.cpp @@ -10,6 +10,7 @@ #include #include +#include "ShellParser.h" #ifndef _WIN32 #include #include @@ -86,20 +87,22 @@ static int runChildProcess(string cmd, string log_file) { // Close log file. ::fclose(fp); } - fprintf(stderr, "\r\n\r\n#### pid=%d,cmd=%s #####\r\n\r\n", getpid(), cmd.data()); + fprintf(stderr, "\r\n#### pid=%d,cmd=%s #####\r\n", getpid(), cmd.data()); - auto params = split(cmd, " "); - // memory leak in child process, it's ok. - char **charpv_params = new char *[params.size() + 1]; - for (int i = 0; i < (int)params.size(); i++) { - std::string &p = params[i]; - charpv_params[i] = (char *)p.data(); + auto result = parse_shell_like(cmd); + if (!result.ok) { + fprintf(stderr, "parse cmd line failed: %s, pos: %ld", result.error_msg.data(), result.error_pos); + return -1; } - // EOF: NULL - charpv_params[params.size()] = NULL; - // TODO: execv or execvp - auto ret = execv(params[0].c_str(), charpv_params); - delete[] charpv_params; + auto argv = make_argv(result.args); + auto argc = 0u; + fprintf(stderr, "\r\n#### args #####\r\n"); + for (auto &arg : argv) { + fprintf(stderr, "arg[%d]: %s\r\n", argc++, arg ? arg : "null"); + } + + fprintf(stderr, "\r\n#### process log #####\r\n"); + auto ret = execv(argv[0], (char * const *)(argv.data())); if (ret < 0) { fprintf(stderr, "execv process failed:%d(%s)\r\n", get_uv_error(), get_uv_errmsg()); diff --git a/server/ShellParser.h b/server/ShellParser.h new file mode 100644 index 00000000..9a9330a0 --- /dev/null +++ b/server/ShellParser.h @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. +*/ + +#ifndef ZLMEDIAKIT_SHELLPARSER_H +#define ZLMEDIAKIT_SHELLPARSER_H + +#include +#include +#include +#include + +// Shell-like command line parser. +// Features: +// - Whitespace splitting (space, tab, newline) +// - Quotes: single ('...') and double ("...") +// - Escapes with backslash (\\) outside quotes +// - In single quotes: backslash is literal (like POSIX shell) +// - In double quotes: backslash can escape " $ ` \\ and newline (line continuation) +// Additionally supports common C-style escapes: \n \t \r \0 .. outside and inside double quotes +// - Line continuation: backslash followed by newline is ignored +// - Produces argv pointers with stable lifetime backed by std::vector +// +// Notes: +// - This is NOT a full shell (no variable expansion, no globbing, no command substitution). +// - Behavior aims to be practical and safe for exec* arguments building. + +struct ParseResult { + ParseResult(bool ok, const char *err, size_t pos, std::vector args) + : ok(ok) + , error_msg(err) + , error_pos(pos) + , args(std::move(args)) {} + + bool ok; + std::string error_msg; + size_t error_pos = 0; // index in input when error happens + std::vector args; // parsed arguments +}; + +namespace detail { + +inline bool is_space(char c) { + return c == ' ' || c == '\t' || c == '\n'; +} + +// Returns true if it handled a line continuation ("\\\n"). +inline bool handle_line_continuation(const std::string &s, size_t &i) { + if (i + 1 < s.size() && s[i] == '\\' && s[i + 1] == '\n') { + i += 2; // consume both and do nothing + return true; + } + return false; +} + +inline bool hex_digit(char c) { return std::isxdigit(static_cast(c)) != 0; } +inline int hex_val(char c) { + if (c >= '0' && c <= '9') return c - '0'; + if (c >= 'a' && c <= 'f') return 10 + (c - 'a'); + if (c >= 'A' && c <= 'F') return 10 + (c - 'A'); + return 0; +} + +// Parse C-style escapes: \n, \t, \r, \0..\377 (octal), \xHH (hex). Returns std::nullopt if not a known escape. +inline std::pair c_style_escape(const std::string &s, size_t &i) { + if (i >= s.size()) return std::make_pair(false, '\0'); + char c = s[i]; + switch (c) { + case 'n': ++i; return std::make_pair(true, '\n'); + case 't': ++i; return std::make_pair(true, '\t'); + case 'r': ++i; return std::make_pair(true, '\r'); + case 'a': ++i; return std::make_pair(true, '\a'); + case 'b': ++i; return std::make_pair(true, '\b'); + case 'f': ++i; return std::make_pair(true, '\f'); + case 'v': ++i; return std::make_pair(true, '\v'); + case '\\': ++i; return std::make_pair(true, '\\'); + case '"': ++i; return std::make_pair(true, '"'); + case '\'': ++i; return std::make_pair(true, '\''); + case '0': { + // up to 3 octal digits total (including the first 0 already consumed here?) + // Here c=='0' means octal sequence starts at current '0'. + // We'll parse up to 3 octal digits starting at current pos. + int val = 0; int cnt = 0; + while (i < s.size() && cnt < 3 && (s[i] >= '0' && s[i] <= '7')) { + val = (val << 3) + (s[i] - '0'); + ++i; ++cnt; + } + return std::make_pair(true, static_cast(val & 0xFF)); + } + case 'x': { + ++i; // consume 'x' + int val = 0; int cnt = 0; + while (i < s.size() && cnt < 2 && hex_digit(s[i])) { + val = (val << 4) + hex_val(s[i]); + ++i; ++cnt; + } + if (cnt == 0) return std::make_pair(false, '\0'); // not actually a hex escape + return std::make_pair(true, static_cast(val & 0xFF)); + } + default: + return std::make_pair(false, '\0'); + } +} + +} + +ParseResult parse_shell_like(const std::string &input) { + using namespace detail; + std::vector args; + std::string cur; + + enum class State { Normal, InSingle, InDouble }; + State st = State::Normal; + + size_t i = 0; const size_t N = input.size(); + while (i < N) { + // line continuation check (\\\n) applies in all states + if (handle_line_continuation(input, i)) continue; + if (i >= N) break; + + char c = input[i]; + switch (st) { + case State::Normal: { + if (is_space(c)) { + if (!cur.empty()) { args.emplace_back(std::move(cur)); cur.clear(); } + ++i; + } else if (c == '\'') { + st = State::InSingle; ++i; + } else if (c == '"') { + st = State::InDouble; ++i; + } else if (c == '\\') { + ++i; // consume backslash + if (i >= N) { + return {false, "结尾处孤立的反斜杠(未转义任何字符)", i, {}}; + } + // Try C-style escapes first + auto esc = c_style_escape(input, i); + if (esc.first) { + cur.push_back(esc.second); + } else { + // Not a known C escape: take the next char literally + cur.push_back(input[i]); + ++i; + } + } else { + cur.push_back(c); ++i; + } + } break; + + case State::InSingle: { + if (c == '\'') { st = State::Normal; ++i; } + else { cur.push_back(c); ++i; } + } break; + + case State::InDouble: { + if (c == '"') { st = State::Normal; ++i; } + else if (c == '\\') { + ++i; // consume backslash + if (i >= N) { + return {false, "双引号内以反斜杠结尾,缺少被转义字符", i, {}}; + } + // In POSIX shell, within double quotes, only certain escapes are special. + // Here we support both POSIX subset and common C-style escapes for practicality. + auto esc = c_style_escape(input, i); + if (esc.first) { + cur.push_back(esc.second); + } else { + // If not a C-style escape, allow escaping one char literally (e.g., $ `) + cur.push_back(input[i]); + ++i; + } + } else { + cur.push_back(c); ++i; + } + } break; + } + } + + if (st == State::InSingle) { + return {false, "缺少配对的单引号(')", i, {}}; + } + if (st == State::InDouble) { + return {false, "缺少配对的双引号(\")", i, {}}; + } + + if (!cur.empty()) args.emplace_back(std::move(cur)); + + return {true, "", 0, std::move(args)}; +} + +// Helper: build argv pointers backed by the strings' storage. +// The returned vector includes a trailing nullptr, suitable for execv*. +inline std::vector make_argv(const std::vector& args) { + std::vector argv; + argv.reserve(args.size() + 1); + for (const auto &s : args) argv.push_back(s.c_str()); + argv.push_back(nullptr); + return argv; +} + +#endif // ZLMEDIAKIT_SHELLPARSER_H diff --git a/server/System.cpp b/server/System.cpp index c37e4958..c76af03e 100644 --- a/server/System.cpp +++ b/server/System.cpp @@ -15,6 +15,12 @@ #if !defined(ANDROID) #include #endif//!defined(ANDROID) +#else +#include +#include +#include +#include +#pragma comment(lib, "DbgHelp.lib") #endif//!defined(_WIN32) #include @@ -213,6 +219,48 @@ void System::systemSetup(){ // Ignore the hang up signal signal(SIGHUP, SIG_IGN); #endif// ANDROID +#else + // 避免系统弹窗导致程序阻塞,适合无界面或后台服务场景。 + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); + +#if !defined(__MINGW32__) + // 将assert和error时错误输出 + _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_DEBUG); + _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_DEBUG); +#endif + + _setmode(0, _O_BINARY); + _setmode(1, _O_BINARY); + _setmode(2, _O_BINARY); + + setvbuf(stdout, NULL, _IONBF, 0); + setvbuf(stderr, NULL, _IONBF, 0); + std::ios_base::sync_with_stdio(false); + + // 注册crash自动生成dump(等价core dump) + SetUnhandledExceptionFilter([](EXCEPTION_POINTERS *pException) -> LONG { + // 生成 dump 文件名,带时间戳 + char dumpPath[MAX_PATH]; + std::time_t t = std::time(nullptr); + std::tm tm; +#ifdef _MSC_VER + localtime_s(&tm, &t); +#else + tm = *std::localtime(&t); +#endif + std::strftime(dumpPath, sizeof(dumpPath), "crash_%Y%m%d_%H%M%S.dmp", &tm); + + HANDLE hFile = CreateFileA(dumpPath, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr); + if (hFile != INVALID_HANDLE_VALUE) { + MINIDUMP_EXCEPTION_INFORMATION mdei; + mdei.ThreadId = GetCurrentThreadId(); + mdei.ExceptionPointers = pException; + mdei.ClientPointers = FALSE; + MiniDumpWriteDump(GetCurrentProcess(), GetCurrentProcessId(), hFile, MiniDumpNormal, &mdei, nullptr, nullptr); + CloseHandle(hFile); + } + return EXCEPTION_EXECUTE_HANDLER; + }); #endif//!defined(_WIN32) } diff --git a/server/VideoStack.cpp b/server/VideoStack.cpp index 6ae4a005..b251405b 100644 --- a/server/VideoStack.cpp +++ b/server/VideoStack.cpp @@ -21,6 +21,14 @@ #define RGB_TO_U(R, G, B) (((-26 * (R) - 87 * (G) + 112 * (B) + 128) >> 8) + 128) #define RGB_TO_V(R, G, B) (((112 * (R) - 102 * (G) - 10 * (B) + 128) >> 8) + 128) +static void fill_yuv_func(const mediakit::FFmpegFrame::Ptr &frame, int y, int u, int v) { + const auto& yuv = frame->get(); + memset(yuv->data[0], y, yuv->linesize[0] * yuv->height); + memset(yuv->data[1], u, yuv->linesize[1] * ((yuv->height + 1) / 2)); + memset(yuv->data[2], v, yuv->linesize[2] * ((yuv->height + 1) / 2)); +} + + INSTANCE_IMP(VideoStackManager) Param::~Param() { @@ -31,6 +39,13 @@ Param::~Param() { Channel::Channel(const std::string& id, int width, int height, AVPixelFormat pixfmt) : _id(id), _width(width), _height(height), _pixfmt(pixfmt) { +#if defined(VIDEOSTACK_KEEP_ASPECT_RATIO) + _keepAspectRatio = true; +#else + _keepAspectRatio = false; +#endif + _lastWidht = 0; + _lastHeight = 0; _tmp = std::make_shared(); _tmp->get()->width = _width; @@ -39,14 +54,9 @@ Channel::Channel(const std::string& id, int width, int height, AVPixelFormat pix av_frame_get_buffer(_tmp->get(), 32); - memset(_tmp->get()->data[0], 0, _tmp->get()->linesize[0] * _height); - memset(_tmp->get()->data[1], 0, _tmp->get()->linesize[1] * _height / 2); - memset(_tmp->get()->data[2], 0, _tmp->get()->linesize[2] * _height / 2); - auto frame = VideoStackManager::Instance().getBgImg(); - _sws = std::make_shared(_pixfmt, _width, _height); - _tmp = _sws->inputFrame(frame); + resizeFrame(frame); } void Channel::addParam(const std::weak_ptr& p) { @@ -60,8 +70,7 @@ void Channel::onFrame(const mediakit::FFmpegFrame::Ptr& frame) { _poller->async([weakSelf, frame]() { auto self = weakSelf.lock(); if (!self) { return; } - self->_tmp = self->_sws->inputFrame(frame); - + self->resizeFrame(frame); self->forEachParam([self](const Param::Ptr& p) { self->fillBuffer(p); }); }); } @@ -110,6 +119,78 @@ void Channel::copyData(const mediakit::FFmpegFrame::Ptr& buf, const Param::Ptr& default: WarnL << "No support pixformat: " << av_get_pix_fmt_name(p->pixfmt); break; } } + +void Channel::resizeFrame(const mediakit::FFmpegFrame::Ptr &frame) { + if (_keepAspectRatio) { + resizeFrameImplWithAspectRatio(frame); + } else { + resizeFrameImplWithoutAspectRatio(frame); + } +} + +void Channel::resizeFrameImplWithAspectRatio(const mediakit::FFmpegFrame::Ptr &frame) { + int srcWidth = frame->get()->width; + int srcHeight = frame->get()->height; + if (srcWidth <= 0 || srcHeight <= 0) { + return; + } + + // 当新frame宽高变化时,重新初始化sws + if (srcWidth != _lastWidht || srcHeight != _lastHeight) { + _lastWidht = srcWidth; + _lastHeight = srcHeight; + fill_yuv_func(_tmp, 16, 128, 128); + + int dstWidth = _width; + int dstHeight = _height; + + float srcAspectRatio = static_cast(srcWidth) / srcHeight; + float dstAspectRatio = static_cast(dstWidth) / dstHeight; + + int scaledWidth, scaledHeight; + if (srcAspectRatio > dstAspectRatio) { + scaledWidth = dstWidth; + scaledHeight = static_cast(dstWidth / srcAspectRatio); + } else { + scaledHeight = dstHeight; + scaledWidth = static_cast(dstHeight * srcAspectRatio); + } + + _offsetX = (dstWidth - scaledWidth) / 2; + _offsetY = (dstHeight - scaledHeight) / 2; + _sws = std::make_shared(_pixfmt, scaledWidth, scaledHeight); + } + + auto scaledFrame = _sws->inputFrame(frame); + + int copyWidth = ((_width) < (scaledFrame->get()->width) ? (_width) : (scaledFrame->get()->width)); + int copyHeight = ((_height) < (scaledFrame->get()->height) ? (_height) : (scaledFrame->get()->height)); + + for (int i = 0; i < copyHeight; i++) { + memcpy( + _tmp->get()->data[0] + (i + _offsetY) * _tmp->get()->linesize[0] + _offsetX, scaledFrame->get()->data[0] + i * scaledFrame->get()->linesize[0], + copyWidth); + } + + for (int i = 0; i < (copyHeight + 1) / 2; i++) { + memcpy( + _tmp->get()->data[1] + (i + _offsetY / 2) * _tmp->get()->linesize[1] + _offsetX / 2, + scaledFrame->get()->data[1] + i * scaledFrame->get()->linesize[1], copyWidth / 2); + memcpy( + _tmp->get()->data[2] + (i + _offsetY / 2) * _tmp->get()->linesize[2] + _offsetX / 2, + scaledFrame->get()->data[2] + i * scaledFrame->get()->linesize[2], copyWidth / 2); + } + +} + +void Channel::resizeFrameImplWithoutAspectRatio(const mediakit::FFmpegFrame::Ptr &frame) { + if (!_sws) { + fill_yuv_func(_tmp, 16, 128, 128); + _sws = std::make_shared(_pixfmt, _width, _height); + } + _tmp = _sws->inputFrame(frame); +} + void StackPlayer::addChannel(const std::weak_ptr& chn) { std::lock_guard lock(_mx); _channels.push_back(chn); @@ -151,10 +232,10 @@ void StackPlayer::play() { // auto audioTrack = std::dynamic_pointer_cast(strongPlayer->getTrack(mediakit::TrackAudio, false)); if (videoTrack) { + // 如果每次不同 可以加个时间戳 time(NULL); // TODO:添加使用显卡还是cpu解码的判断逻辑 [AUTO-TRANSLATED:44bef37a] // TODO: Add logic to determine whether to use GPU or CPU decoding - auto decoder = std::make_shared( - videoTrack, 0, std::vector{"h264", "hevc"}); + auto decoder = std::make_shared(videoTrack, 0, std::vector { "h264", "hevc" }); decoder->setOnDecode([weakSelf](const mediakit::FFmpegFrame::Ptr& frame) mutable { auto self = weakSelf.lock(); @@ -300,9 +381,7 @@ void VideoStack::initBgColor() { double U = RGB_TO_U(R, G, B); double V = RGB_TO_V(R, G, B); - memset(_buffer->get()->data[0], Y, _buffer->get()->linesize[0] * _height); - memset(_buffer->get()->data[1], U, _buffer->get()->linesize[1] * _height / 2); - memset(_buffer->get()->data[2], V, _buffer->get()->linesize[2] * _height / 2); + fill_yuv_func(_buffer, Y, U, V); } Channel::Ptr VideoStackManager::getChannel(const std::string& id, int width, int height, diff --git a/server/VideoStack.h b/server/VideoStack.h index 1f9c114b..02c1604f 100644 --- a/server/VideoStack.h +++ b/server/VideoStack.h @@ -62,12 +62,24 @@ protected: void copyData(const mediakit::FFmpegFrame::Ptr& buf, const Param::Ptr& p); + void resizeFrame(const mediakit::FFmpegFrame::Ptr &frame); + + void resizeFrameImplWithAspectRatio(const mediakit::FFmpegFrame::Ptr &frame); + + void resizeFrameImplWithoutAspectRatio(const mediakit::FFmpegFrame::Ptr &frame); + private: std::string _id; int _width; int _height; AVPixelFormat _pixfmt; + int _lastWidht; + int _lastHeight; + bool _keepAspectRatio; + int _offsetX; + int _offsetY; + mediakit::FFmpegFrame::Ptr _tmp; std::recursive_mutex _mx; diff --git a/server/WebApi.cpp b/server/WebApi.cpp index f5071027..80e68b50 100755 --- a/server/WebApi.cpp +++ b/server/WebApi.cpp @@ -47,6 +47,7 @@ #include "Player/PlayerProxy.h" #include "Pusher/PusherProxy.h" #include "Rtp/RtpProcess.h" +#include "Rtp/RtpSender.h" #include "Record/MP4Reader.h" #if defined(ENABLE_RTPPROXY) @@ -57,6 +58,10 @@ #include "../webrtc/WebRtcPlayer.h" #include "../webrtc/WebRtcPusher.h" #include "../webrtc/WebRtcEchoTest.h" +#include "../webrtc/WebRtcSignalingPeer.h" +#include "../webrtc/WebRtcSignalingSession.h" +#include "../webrtc/WebRtcProxyPlayer.h" +#include "../webrtc/WebRtcProxyPlayerImp.h" #endif #if defined(ENABLE_VERSION) @@ -67,6 +72,9 @@ #include "VideoStack.h" #endif +#include "Onvif/Onvif.h" +#include "Onvif/SoapUtil.h" + using namespace std; using namespace Json; using namespace toolkit; @@ -94,21 +102,21 @@ using HttpApi = function s_map_api; -static void responseApi(const Json::Value &res, const HttpSession::HttpResponseInvoker &invoker){ - GET_CONFIG(string, charSet, Http::kCharSet); - HttpSession::KeyValue headerOut; - headerOut["Content-Type"] = string("application/json; charset=") + charSet; - invoker(200, headerOut, res.toStyledString()); -}; - -static void responseApi(int code, const string &msg, const HttpSession::HttpResponseInvoker &invoker){ +static void responseApi(int code, const string &msg, const HttpSession::HttpResponseInvoker &invoker, ApiRetException *ex = nullptr){ Json::Value res; + HttpSession::KeyValue headerOut; + if (ex) { + res = ex->getBody(); + headerOut = ex->getHeaders(); + } res["code"] = code; res["msg"] = msg; - responseApi(res, invoker); -} -static ApiArgsType getAllArgs(const Parser &parser); + GET_CONFIG(string, charSet, Http::kCharSet); + headerOut["Content-Type"] = string("application/json; charset=") + charSet; + + invoker(200, headerOut, res.toStyledString()); +} static HttpApi toApi(const function &cb) { return [cb](const Parser &parser, const HttpSession::HttpResponseInvoker &invoker, SockInfo &sender) { @@ -208,7 +216,7 @@ void api_regist(const string &api_path, const function> jsonArgs; - auto keys = jsonArgs.getMemberNames(); - for (auto key = keys.begin(); key != keys.end(); ++key) { - allArgs[*key] = jsonArgs[*key].asString(); + if (!parser.content().empty()) { + try { + stringstream ss(parser.content()); + Value jsonArgs; + ss >> jsonArgs; + auto keys = jsonArgs.getMemberNames(); + for (auto key = keys.begin(); key != keys.end(); ++key) { + allArgs[*key] = jsonArgs[*key].asString(); + } + } catch (std::exception &ex) { + WarnL << ex.what(); } - } catch (std::exception &ex) { - WarnL << ex.what(); } } else if (!parser["Content-Type"].empty()) { WarnL << "invalid Content-Type:" << parser["Content-Type"]; @@ -293,92 +303,27 @@ static inline void addHttpListener(){ }; ((HttpSession::HttpResponseInvoker &) invoker) = newInvoker; } - - try { - it->second(parser, invoker, sender); - } catch (ApiRetException &ex) { - responseApi(ex.code(), ex.what(), invoker); - auto helper = static_cast(sender).shared_from_this(); - helper->getPoller()->async([helper, ex]() { helper->shutdown(SockException(Err_shutdown, ex.what())); }, false); - } + auto helper = static_cast(sender).shared_from_this(); + // 在本poller线程下一次事件循环时执行http api,防止占用NoticeCenter的锁 + helper->getPoller()->async([it, parser, invoker, helper]() { + try { + it->second(parser, invoker, *helper); + } catch (ApiRetException &ex) { + responseApi(ex.code(), ex.what(), invoker, &ex); + helper->getPoller()->async([helper, ex]() { helper->shutdown(SockException(Err_shutdown, ex.what())); }, false); + } #ifdef ENABLE_MYSQL - catch(SqlException &ex){ - responseApi(API::SqlFailed, StrPrinter << "操作数据库失败:" << ex.what() << ":" << ex.getSql(), invoker); - } -#endif// ENABLE_MYSQL - catch (std::exception &ex) { - responseApi(API::Exception, ex.what(), invoker); - } + catch (SqlException &ex) { + responseApi(API::SqlFailed, StrPrinter << "操作数据库失败:" << ex.what() << ":" << ex.getSql(), invoker, &ex); + } +#endif // ENABLE_MYSQL + catch (std::exception &ex) { + responseApi(API::Exception, ex.what(), invoker); + } + },false); }); } -template -class ServiceController { -public: - using Pointer = std::shared_ptr; - std::unordered_map _map; - mutable std::recursive_mutex _mtx; - - void clear() { - decltype(_map) copy; - { - std::lock_guard lck(_mtx); - copy.swap(_map); - } - } - - size_t erase(const std::string &key) { - std::lock_guard lck(_mtx); - return _map.erase(key); - } - - size_t size() { - std::lock_guard lck(_mtx); - return _map.size(); - } - - Pointer find(const std::string &key) const { - std::lock_guard lck(_mtx); - auto it = _map.find(key); - if (it == _map.end()) { - return nullptr; - } - return it->second; - } - - void for_each(const std::function& cb) { - std::lock_guard lck(_mtx); - auto it = _map.begin(); - while (it != _map.end()) { - cb(it->first, it->second); - it++; - } - } - - template - Pointer make(const std::string &key, _Args&& ...__args) { - // assert(!find(key)); - - auto server = std::make_shared(std::forward<_Args>(__args)...); - std::lock_guard lck(_mtx); - auto it = _map.emplace(key, server); - assert(it.second); - return server; - } - - template - Pointer makeWithAction(const std::string &key, function action, _Args&& ...__args) { - // assert(!find(key)); - - auto server = std::make_shared(std::forward<_Args>(__args)...); - action(server); - std::lock_guard lck(_mtx); - auto it = _map.emplace(key, server); - assert(it.second); - return server; - } -}; - // 拉流代理器列表 [AUTO-TRANSLATED:6dcfb11f] // Pull stream proxy list static ServiceController s_player_proxy; @@ -403,7 +348,7 @@ static inline string getPusherKey(const string &schema, const string &vhost, con return schema + "/" + vhost + "/" + app + "/" + stream + "/" + MD5(dst_url).hexdigest(); } -static void fillSockInfo(Value& val, SockInfo* info) { +void fillSockInfo(Value& val, SockInfo* info) { val["peer_ip"] = info->get_peer_ip(); val["peer_port"] = info->get_peer_port(); val["local_port"] = info->get_local_port(); @@ -424,30 +369,82 @@ Value ToJson(const PusherProxy::Ptr& p) { item["status"] = p->getStatus(); item["liveSecs"] = p->getLiveSecs(); item["rePublishCount"] = p->getRePublishCount(); + item["bytesSpeed"] = (Json::UInt64) p->getSendSpeed(); + item["totalBytes"] =(Json::UInt64) p->getSendTotalBytes(); + if (auto src = p->getSrc()) { dumpMediaTuple(src->getMediaTuple(), item["src"]); } return item; } +Json::Value dumpTracks(const std::vector &tracks) { + Json::Value ret(arrayValue); + for (auto &track : tracks) { + Value obj; + auto codec_type = track->getTrackType(); + obj["codec_id"] = track->getCodecId(); + obj["codec_id_name"] = track->getCodecName(); + obj["ready"] = track->ready(); + obj["codec_type"] = codec_type; + obj["frames"] = track->getFrames(); + obj["duration"] = track->getDuration(); + switch (codec_type) { + case TrackAudio: { + auto audio_track = dynamic_pointer_cast(track); + obj["sample_rate"] = audio_track->getAudioSampleRate(); + obj["channels"] = audio_track->getAudioChannel(); + obj["sample_bit"] = audio_track->getAudioSampleBit(); + break; + } + case TrackVideo: { + auto video_track = dynamic_pointer_cast(track); + obj["width"] = video_track->getVideoWidth(); + obj["height"] = video_track->getVideoHeight(); + obj["key_frames"] = video_track->getVideoKeyFrames(); + int gop_size = video_track->getVideoGopSize(); + int gop_interval_ms = video_track->getVideoGopInterval(); + float fps = video_track->getVideoFps(); + if (fps <= 1 && gop_interval_ms) { + fps = gop_size * 1000.0 / gop_interval_ms; + } + obj["fps"] = round(fps); + obj["gop_size"] = gop_size; + obj["gop_interval_ms"] = gop_interval_ms; + break; + } + default: break; + } + ret.append(obj); + } + return ret; +} + Value ToJson(const PlayerProxy::Ptr& p) { Value item; item["url"] = p->getUrl(); item["status"] = p->getStatus(); + item["status_str"] = p->getStatusStr(); item["liveSecs"] = p->getLiveSecs(); item["rePullCount"] = p->getRePullCount(); item["totalReaderCount"] = p->totalReaderCount(); + item["bytesSpeed"] = (Json::UInt64) p->getRecvSpeed(); + item["totalBytes"] = (Json::UInt64) p->getRecvTotalBytes(); + dumpMediaTuple(p->getMediaTuple(), item["src"]); + item["tracks"] = dumpTracks(p->getTracks(false)); return item; } -Value makeMediaSourceJson(MediaSource &media){ +Value makeMediaSourceJson(MediaSource &media) { Value item; item["schema"] = media.getSchema(); dumpMediaTuple(media.getMediaTuple(), item); item["createStamp"] = (Json::UInt64) media.getCreateStamp(); + item["currentStamp"] = (Json::UInt64) media.getTimeStamp(TrackInvalid); item["aliveSecond"] = (Json::UInt64) media.getAliveSecond(); - item["bytesSpeed"] = media.getBytesSpeed(); + item["bytesSpeed"] = (Json::UInt64) media.getBytesSpeed(); + item["totalBytes"] = (Json::UInt64) media.getTotalBytes(); item["readerCount"] = media.readerCount(); item["totalReaderCount"] = media.totalReaderCount(); item["originType"] = (int) media.getOriginType(); @@ -467,17 +464,13 @@ Value makeMediaSourceJson(MediaSource &media){ auto current_thread = false; try { current_thread = media.getOwnerPoller()->isCurrentThread();} catch (...) {} float last_loss = -1; - for(auto &track : media.getTracks(false)){ - Value obj; - auto codec_type = track->getTrackType(); - obj["codec_id"] = track->getCodecId(); - obj["codec_id_name"] = track->getCodecName(); - obj["ready"] = track->ready(); - obj["codec_type"] = codec_type; - if (current_thread) { + auto tracks = dumpTracks(media.getTracks(false)); + if (current_thread) { + for (auto &obj : tracks) { // rtp推流只有一个统计器,但是可能有多个track,如果短时间多次获取间隔丢包率,第二次会获取为-1 [AUTO-TRANSLATED:5bfbc951] - // RTP push stream has only one statistics, but may have multiple tracks. If you get the interval packet loss rate multiple times in a short time, the second time will get -1 - auto loss = media.getLossRate(codec_type); + // RTP push stream has only one statistics, but may have multiple tracks. If you get the interval packet loss rate multiple times in a short time, + // the second time will get -1 + auto loss = media.getLossRate(getTrackType(static_cast(obj["codec_type"].asInt()))); if (loss == -1) { loss = last_loss; } else { @@ -485,37 +478,8 @@ Value makeMediaSourceJson(MediaSource &media){ } obj["loss"] = loss; } - obj["frames"] = track->getFrames(); - obj["duration"] = track->getDuration(); - switch(codec_type){ - case TrackAudio : { - auto audio_track = dynamic_pointer_cast(track); - obj["sample_rate"] = audio_track->getAudioSampleRate(); - obj["channels"] = audio_track->getAudioChannel(); - obj["sample_bit"] = audio_track->getAudioSampleBit(); - break; - } - case TrackVideo : { - auto video_track = dynamic_pointer_cast(track); - obj["width"] = video_track->getVideoWidth(); - obj["height"] = video_track->getVideoHeight(); - obj["key_frames"] = video_track->getVideoKeyFrames(); - int gop_size = video_track->getVideoGopSize(); - int gop_interval_ms = video_track->getVideoGopInterval(); - float fps = video_track->getVideoFps(); - if (fps <= 1 && gop_interval_ms) { - fps = gop_size * 1000.0 / gop_interval_ms; - } - obj["fps"] = round(fps); - obj["gop_size"] = gop_size; - obj["gop_interval_ms"] = gop_interval_ms; - break; - } - default: - break; - } - item["tracks"].append(obj); } + item["tracks"] = std::move(tracks); return item; } @@ -629,8 +593,19 @@ void getStatisticJson(const function &cb) { #endif } -void addStreamProxy(const MediaTuple &tuple, const string &url, int retry_count, - const ProtocolOption &option, int rtp_type, float timeout_sec, const mINI &args, +void updateStreamProxy(const mediakit::MediaTuple &tuple, const std::string &url, const toolkit::mINI &args) { + auto key = tuple.shortUrl(); + auto player = s_player_proxy.find(key); + if (!player) { + throw std::runtime_error("proxy player not found: " + key); + } + player->getPoller()->async([url, args, player]() { + player->update(url, args); + }); +} + +void addStreamProxy(const MediaTuple &tuple, const string &url, int retry_count, bool force, + const ProtocolOption &option, float timeout_sec, const mINI &args, const function &cb) { auto key = tuple.shortUrl(); if (s_player_proxy.find(key)) { @@ -649,10 +624,6 @@ void addStreamProxy(const MediaTuple &tuple, const string &url, int retry_count, (*player)[pr.first] = pr.second; } - // 指定RTP over TCP(播放rtsp时有效) [AUTO-TRANSLATED:1a062656] - // Specify RTP over TCP (effective when playing RTSP) - (*player)[Client::kRtpType] = rtp_type; - if (timeout_sec > 0.1f) { // 播放握手超时时间 [AUTO-TRANSLATED:5a29ae1f] // Play handshake timeout @@ -661,11 +632,18 @@ void addStreamProxy(const MediaTuple &tuple, const string &url, int retry_count, // 开始播放,如果播放失败或者播放中止,将会自动重试若干次,默认一直重试 [AUTO-TRANSLATED:ac8499e5] // Start playing. If playback fails or is stopped, it will automatically retry several times, by default it will retry indefinitely - player->setPlayCallbackOnce([cb, key](const SockException &ex) { - if (ex) { - s_player_proxy.erase(key); + player->setPlayCallbackOnce([cb, key, force](const SockException &ex) { + if (force) { + // 强制添加成功 + cb(SockException(), key); + } else { + // 非强制添加 + if (ex) { + // 失败则移除记录 + s_player_proxy.erase(key); + } + cb(ex, key); } - cb(ex, key); }); // 被主动关闭拉流 [AUTO-TRANSLATED:41a19476] @@ -739,6 +717,73 @@ void addStreamPusherProxy(const string &schema, pusher->publish(url); } +void getThreadsLoad(TaskExecutorGetterImp &getter, API_ARGS_MAP_ASYNC) { + getter.getExecutorDelay([&getter, invoker, headerOut](const vector &vecDelay) { + Value val; + auto vec = getter.getExecutorLoad(); + std::vector pollers; + getter.for_each([&](const TaskExecutor::Ptr &exe) { pollers.emplace_back(std::static_pointer_cast(exe)); }); + int i = API::Success; + for (auto load : vec) { + Value obj(objectValue); + obj["load"] = load; + auto &poller = pollers[i]; + obj["name"] = poller->getThreadName(); + obj["fd_count"] = static_cast(poller->fdCount()); + obj["delay"] = vecDelay[i++]; + val["data"].append(obj); + } + val["code"] = API::Success; + invoker(200, headerOut, val.toStyledString()); + }); +} + +static constexpr char kLoginCookiePath[] = "/"; +static constexpr char kUnLoginCookieName[] = "ZLM_UNLOGIN"; +static constexpr char kLoginedCookieName[] = "ZLM_LOGINED"; +static constexpr size_t kUnLoginCookieLifeSeconds = 60; +static constexpr size_t kLoginedCookieLifeSeconds = 24 * 3600; + +template +void check_secret(toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const HttpAllArgs &allArgs, Json::Value &val) { + GET_CONFIG(std::string, api_secret, API::kSecret); + + auto ip = sender.get_peer_ip(); + if (!HttpFileManager::isIPAllowed(ip)) { + throw AuthException("Your ip is not allowed to access the service."); + } + + try { + auto logined_cookie = HttpCookieManager::Instance().getCookie(kLoginedCookieName, allArgs.getParser().getHeader()); + if (!logined_cookie) { + auto unlogin_cookie = HttpCookieManager::Instance().getCookie(kUnLoginCookieName, allArgs.getParser().getHeader()); + if (!unlogin_cookie) { + unlogin_cookie = HttpCookieManager::Instance().addCookie(kUnLoginCookieName, "", kUnLoginCookieLifeSeconds); + headerOut["Set-Cookie"] = unlogin_cookie->getCookie(kLoginCookiePath); + } + val["cookie"] = unlogin_cookie->getCookie(); + throw AuthException("Please login first", headerOut, val); + } + // 优先cookie登陆鉴权 + } catch (...) { + try { + // cookie登陆鉴权失败了再比对secret + CHECK_ARGS("secret"); + if (api_secret != allArgs["secret"]) { + throw AuthException("Incorrect secret"); + } + return; + } catch (...) { + // 未提供secret或secret不匹配,这个异常隐藏 + } + // secret鉴权模式失败,抛出要求cookie登录的异常 + throw; + } +} + +template void check_secret(toolkit::SockInfo &, mediakit::HttpSession::KeyValue &, const HttpAllArgs &, Json::Value &); +template void check_secret(toolkit::SockInfo &, mediakit::HttpSession::KeyValue &, const HttpAllArgs &, Json::Value &); +template void check_secret(toolkit::SockInfo &, mediakit::HttpSession::KeyValue &, const HttpAllArgs &, Json::Value &); /** * 安装api接口 @@ -747,12 +792,11 @@ void addStreamPusherProxy(const string &schema, * Install api interface * All apis support GET and POST methods * POST method parameters support application/json and application/x-www-form-urlencoded methods - + * [AUTO-TRANSLATED:62e68c43] */ void installWebApi() { addHttpListener(); - GET_CONFIG(string,api_secret,API::kSecret); // 获取线程负载 [AUTO-TRANSLATED:3b0ece5c] // Get thread load @@ -760,19 +804,7 @@ void installWebApi() { // Test url http://127.0.0.1/index/api/getThreadsLoad api_regist("/index/api/getThreadsLoad", [](API_ARGS_MAP_ASYNC) { CHECK_SECRET(); - EventPollerPool::Instance().getExecutorDelay([invoker, headerOut](const vector &vecDelay) { - Value val; - auto vec = EventPollerPool::Instance().getExecutorLoad(); - int i = API::Success; - for (auto load : vec) { - Value obj(objectValue); - obj["load"] = load; - obj["delay"] = vecDelay[i++]; - val["data"].append(obj); - } - val["code"] = API::Success; - invoker(200, headerOut, val.toStyledString()); - }); + getThreadsLoad(EventPollerPool::Instance(), API_ARGS_VALUE, invoker); }); // 获取后台工作线程负载 [AUTO-TRANSLATED:6166e265] @@ -781,19 +813,7 @@ void installWebApi() { // Test url http://127.0.0.1/index/api/getWorkThreadsLoad api_regist("/index/api/getWorkThreadsLoad", [](API_ARGS_MAP_ASYNC) { CHECK_SECRET(); - WorkThreadPool::Instance().getExecutorDelay([invoker, headerOut](const vector &vecDelay) { - Value val; - auto vec = WorkThreadPool::Instance().getExecutorLoad(); - int i = 0; - for (auto load : vec) { - Value obj(objectValue); - obj["load"] = load; - obj["delay"] = vecDelay[i++]; - val["data"].append(obj); - } - val["code"] = API::Success; - invoker(200, headerOut, val.toStyledString()); - }); + getThreadsLoad(WorkThreadPool::Instance(), API_ARGS_VALUE, invoker); }); // 获取服务器配置 [AUTO-TRANSLATED:7dd2f3da] @@ -966,13 +986,28 @@ void installWebApi() { // Test url1 (get streams with virtual host "__defaultVost__") http://127.0.0.1/index/api/getMediaList?vhost=__defaultVost__ // 测试url2(获取rtsp类型的流) http://127.0.0.1/index/api/getMediaList?schema=rtsp [AUTO-TRANSLATED:21c2c15d] // Test url2 (get rtsp type streams) http://127.0.0.1/index/api/getMediaList?schema=rtsp - api_regist("/index/api/getMediaList",[](API_ARGS_MAP){ + api_regist("/index/api/getMediaList",[](API_ARGS_MAP_ASYNC){ CHECK_SECRET(); // 获取所有MediaSource列表 [AUTO-TRANSLATED:7bf16dc2] // Get all MediaSource lists + std::list lst; MediaSource::for_each_media([&](const MediaSource::Ptr &media) { - val["data"].append(makeMediaSourceJson(*media)); + lst.emplace_back(media); }, allArgs["schema"], allArgs["vhost"], allArgs["app"], allArgs["stream"]); + + if (lst.size() == 1) { + // 如果是搜索单一流,那么在它的归属线程中执行,用于获取丢包率参数 + auto front = std::move(lst.front()); + front->getOwnerPoller()->async([=]() mutable { + val["data"].append(makeMediaSourceJson(*front)); + invoker(200, headerOut, val.toStyledString()); + }); + } else { + for (auto &media : lst) { + val["data"].append(makeMediaSourceJson(*media)); + } + invoker(200, headerOut, val.toStyledString()); + } }); // 测试url http://127.0.0.1/index/api/isMediaOnline?schema=rtsp&vhost=__defaultVhost__&app=live&stream=obs [AUTO-TRANSLATED:126a75e8] @@ -1007,9 +1042,9 @@ void installWebApi() { }, [](toolkit::Any &&info) -> toolkit::Any { auto obj = std::make_shared(); - auto &sock = info.get(); - fillSockInfo(*obj, &sock); - (*obj)["typeid"] = toolkit::demangle(typeid(sock).name()); + auto &session = info.get(); + fillSockInfo(*obj, &session); + (*obj)["typeid"] = toolkit::demangle(typeid(session).name()); toolkit::Any ret; ret.set(obj); return ret; @@ -1090,9 +1125,8 @@ void installWebApi() { bool force = allArgs["force"].as(); for (auto &media : media_list) { - if (media->close(force)) { - ++count_closed; - } + media->getOwnerPoller()->async([media, force]() { media->close(force); }); + ++count_closed; } val["count_hit"] = count_hit; val["count_closed"] = count_closed; @@ -1119,6 +1153,7 @@ void installWebApi() { } fillSockInfo(jsession, session.get()); jsession["id"] = id; + jsession["type"] = session->getSock()->sockType() == SockNum::Sock_TCP ? "tcp" : "udp"; jsession["typeid"] = toolkit::demangle(typeid(*session).name()); val["data"].append(jsession); }); @@ -1189,25 +1224,27 @@ void installWebApi() { auto dst_url = allArgs["dst_url"]; auto retry_count = allArgs["retry_count"].empty() ? -1 : allArgs["retry_count"].as(); - addStreamPusherProxy(allArgs["schema"], - allArgs["vhost"], - allArgs["app"], - allArgs["stream"], - allArgs["dst_url"], - retry_count, - allArgs["rtp_type"], - allArgs["timeout_sec"], - args, - [invoker, val, headerOut, dst_url](const SockException &ex, const string &key) mutable { - if (ex) { - val["code"] = API::OtherFailed; - val["msg"] = ex.what(); - } else { - val["data"]["key"] = key; - InfoL << "Publish success, please play with player:" << dst_url; - } - invoker(200, headerOut, val.toStyledString()); - }); + EventPollerPool::Instance().getPoller(false)->async([=]() mutable { + addStreamPusherProxy(allArgs["schema"], + allArgs["vhost"], + allArgs["app"], + allArgs["stream"], + allArgs["dst_url"], + retry_count, + allArgs["rtp_type"], + allArgs["timeout_sec"], + args, + [invoker, val, headerOut, dst_url](const SockException &ex, const string &key) mutable { + if (ex) { + val["code"] = API::OtherFailed; + val["msg"] = ex.what(); + } else { + val["data"]["key"] = key; + InfoL << "Publish success, please play with player:" << dst_url; + } + invoker(200, headerOut, val.toStyledString()); + }); + }); }); // 关闭推流代理 [AUTO-TRANSLATED:91602b75] @@ -1221,19 +1258,19 @@ void installWebApi() { }); api_regist("/index/api/listStreamPusherProxy", [](API_ARGS_MAP) { CHECK_SECRET(); - s_pusher_proxy.for_each([&val](const std::string& key, const PusherProxy::Ptr& p) { + s_pusher_proxy.for_each([&val](const std::string &key, const PusherProxy::Ptr &p) { Json::Value item = ToJson(p); item["key"] = key; val["data"].append(item); - }); + }, allArgs["key"]); }); api_regist("/index/api/listStreamProxy", [](API_ARGS_MAP) { CHECK_SECRET(); - s_player_proxy.for_each([&val](const std::string& key, const PlayerProxy::Ptr& p) { + s_player_proxy.for_each([&val](const std::string &key, const PlayerProxy::Ptr &p) { Json::Value item = ToJson(p); item["key"] = key; val["data"].append(item); - }); + }, allArgs["key"]); }); // 动态添加rtsp/rtmp拉流代理 [AUTO-TRANSLATED:2616537c] // Dynamically add rtsp/rtmp pull stream proxy @@ -1256,22 +1293,24 @@ void installWebApi() { vhost = allArgs["vhost"]; } auto tuple = MediaTuple { vhost, allArgs["app"], allArgs["stream"], "" }; - addStreamProxy(tuple, - allArgs["url"], - retry_count, - option, - allArgs["rtp_type"], - allArgs["timeout_sec"], - args, - [invoker,val,headerOut](const SockException &ex,const string &key) mutable{ - if (ex) { - val["code"] = API::OtherFailed; - val["msg"] = ex.what(); - } else { - val["data"]["key"] = key; - } - invoker(200, headerOut, val.toStyledString()); - }); + EventPollerPool::Instance().getPoller(false)->async([=]() mutable { + addStreamProxy(tuple, + allArgs["url"], + retry_count, + allArgs["force"], + option, + allArgs["timeout_sec"], + args, + [invoker,val,headerOut](const SockException &ex,const string &key) mutable { + if (ex) { + val["code"] = API::OtherFailed; + val["msg"] = ex.what(); + } else { + val["data"]["key"] = key; + } + invoker(200, headerOut, val.toStyledString()); + }); + }); }); // 关闭拉流代理 [AUTO-TRANSLATED:5204f128] @@ -1542,20 +1581,18 @@ void installWebApi() { api_regist("/index/api/listRtpServer",[](API_ARGS_MAP){ CHECK_SECRET(); - std::lock_guard lck(s_rtp_server._mtx); - for (auto &pr : s_rtp_server._map) { - auto vec = split(pr.first, "/"); + s_rtp_server.for_each([&val](const std::string &key, const RtpServer::Ptr &rtps) { + auto vec = split(key, "/"); Value obj; obj["vhost"] = vec[0]; obj["app"] = vec[1]; obj["stream_id"] = vec[2]; - auto& rtps = pr.second; obj["port"] = rtps->getPort(); obj["ssrc"] = rtps->getSSRC(); obj["tcp_mode"] = rtps->getTcpMode(); obj["only_track"] = rtps->getOnlyTrack(); val["data"].append(obj); - } + }); }); static auto start_send_rtp = [] (bool passive, API_ARGS_MAP_ASYNC) { @@ -1590,6 +1627,7 @@ void installWebApi() { // Record the app and vhost of the sending stream args.recv_stream_app = allArgs["app"]; args.recv_stream_vhost = allArgs["vhost"]; + args.enable_origin_recv_limit = allArgs["enable_origin_recv_limit"]; src->getOwnerPoller()->async([=]() mutable { try { src->startSendRtp(args, [val, headerOut, invoker](uint16_t local_port, const SockException &ex) mutable { @@ -1636,6 +1674,7 @@ void installWebApi() { args.recv_stream_id = allArgs["recv_stream_id"]; args.recv_stream_app = allArgs["app"]; args.recv_stream_vhost = allArgs["vhost"]; + args.enable_origin_recv_limit = allArgs["enable_origin_recv_limit"]; src->getOwnerPoller()->async([=]() mutable { try { @@ -1668,8 +1707,10 @@ void installWebApi() { CHECK(muxer, "get muxer from media source failed"); src->getOwnerPoller()->async([=]() mutable { - muxer->forEachRtpSender([&](const std::string &ssrc) mutable { + muxer->forEachRtpSender([&](const std::string &ssrc, const RtpSender &sender) mutable { val["data"].append(ssrc); + val["bytesSpeed"] = (Json::UInt64)sender.getSendSpeed(); + val["totalBytes"] = (Json::UInt64)sender.getSendTotalBytes(); }); invoker(200, headerOut, val.toStyledString()); }); @@ -1713,7 +1754,7 @@ void installWebApi() { auto src = MediaSource::find(vhost, app, allArgs["stream_id"]); auto process = src ? src->getRtpProcess() : nullptr; if (process) { - process->setStopCheckRtp(true); + process->pauseRtpTimeout(true, allArgs["pause_seconds"]); } else { val["code"] = API::NotFound; } @@ -1733,7 +1774,7 @@ void installWebApi() { auto src = MediaSource::find(vhost, app, allArgs["stream_id"]); auto process = src ? src->getRtpProcess() : nullptr; if (process) { - process->setStopCheckRtp(false); + process->pauseRtpTimeout(false); } else { val["code"] = API::NotFound; } @@ -1761,6 +1802,30 @@ void installWebApi() { }); }); + api_regist("/index/api/startRecordTask",[](API_ARGS_MAP_ASYNC){ + CHECK_SECRET(); + CHECK_ARGS("vhost", "app", "stream", "path", "back_ms", "forward_ms"); + + auto src = MediaSource::find(allArgs["vhost"], allArgs["app"], allArgs["stream"]); + if (!src) { + throw ApiRetException("can not find the stream", API::NotFound); + } + + src->getOwnerPoller()->async([=]() mutable { + std::string err; + std::string path; + try { + path = src->getMuxer()->startRecord(allArgs["path"], allArgs["back_ms"], allArgs["forward_ms"]); + } catch (std::exception &ex) { + err = ex.what(); + } + val["code"] = err.empty() ? API::Success : API::OtherFailed; + val["data"]["path"] = path; + val["msg"] = err; + invoker(200, headerOut, val.toStyledString()); + }); + }); + // 设置录像流播放速度 [AUTO-TRANSLATED:a8d82298] // Set the playback speed of the recording stream api_regist("/index/api/setRecordSpeed", [](API_ARGS_MAP_ASYNC) { @@ -1873,11 +1938,13 @@ void installWebApi() { // http://127.0.0.1/index/api/deleteRecordDirectroy?vhost=__defaultVhost__&app=live&stream=ss&period=2020-01-01 api_regist("/index/api/deleteRecordDirectory", [](API_ARGS_MAP) { CHECK_SECRET(); - CHECK_ARGS("vhost", "app", "stream", "period"); + CHECK_ARGS("vhost", "app", "stream"); auto tuple = MediaTuple{allArgs["vhost"], allArgs["app"], allArgs["stream"], ""}; auto record_path = Recorder::getRecordPath(Recorder::type_mp4, tuple, allArgs["customized_path"]); auto period = allArgs["period"]; - record_path = record_path + period + "/"; + if (!period.empty()) { + record_path = record_path + period + "/"; + } bool recording = false; auto name = allArgs["name"]; @@ -1912,6 +1979,15 @@ void installWebApi() { File::deleteEmptyDir(record_path); }); + api_regist("/index/api/deleteSnapDirectory", [](API_ARGS_MAP) { + CHECK_SECRET(); + CHECK_ARGS("vhost", "app", "stream"); + GET_CONFIG(std::string, root, API::kSnapRoot); + auto path = File::absolutePath(allArgs["vhost"] + "/" + allArgs["app"] + "/" + allArgs["stream"] + "/" + allArgs["file"], root); + InfoL << "delete " << path; + File::delete_file(path, true); + }); + // 获取录像文件夹列表或mp4文件列表 [AUTO-TRANSLATED:f7e299bc] // Get the list of recording folders or mp4 files //http://127.0.0.1/index/api/getMP4RecordFile?vhost=__defaultVhost__&app=live&stream=ss&period=2020-01 @@ -2047,7 +2123,7 @@ void installWebApi() { // 启动FFmpeg进程,开始截图,生成临时文件,截图成功后替换为正式文件 [AUTO-TRANSLATED:7d589e3f] // Start the FFmpeg process, start taking screenshots, generate temporary files, replace them with formal files after successful screenshots auto new_snap_tmp = new_snap + ".tmp"; - FFmpegSnap::makeSnap(allArgs["url"], new_snap_tmp, allArgs["timeout_sec"], [invoker, allArgs, new_snap, new_snap_tmp](bool success, const string &err_msg) { + FFmpegSnap::makeSnap(allArgs["async"], allArgs["url"], new_snap_tmp, allArgs["timeout_sec"], [invoker, allArgs, new_snap, new_snap_tmp](bool success, const string &err_msg) { if (!success) { // 生成截图失败,可能残留空文件 [AUTO-TRANSLATED:c96a4468] // Screenshot generation failed, there may be residual empty files @@ -2071,35 +2147,6 @@ void installWebApi() { }); #ifdef ENABLE_WEBRTC - class WebRtcArgsImp : public WebRtcArgs { - public: - WebRtcArgsImp(const ArgsString &args, std::string session_id) - : _args(args) - , _session_id(std::move(session_id)) {} - ~WebRtcArgsImp() override = default; - - toolkit::variant operator[](const string &key) const override { - if (key == "url") { - return getUrl(); - } - return _args[key]; - } - - private: - string getUrl() const{ - auto &allArgs = _args; - CHECK_ARGS("app", "stream"); - - string auth = _args["Authorization"]; // Authorization Bearer - return StrPrinter << "rtc://" << _args["Host"] << "/" << _args["app"] << "/" << _args["stream"] << "?" - << _args.parser.params() + "&session=" + _session_id + (auth.empty() ? "" : ("&Authorization=" + auth)); - } - - private: - ArgsString _args; - std::string _session_id; - }; - api_regist("/index/api/webrtc",[](API_ARGS_STRING_ASYNC){ CHECK_ARGS("type"); auto type = allArgs["type"]; @@ -2107,7 +2154,7 @@ void installWebApi() { CHECK(!offer.empty(), "http body(webrtc offer sdp) is empty"); auto &session = static_cast(sender); - auto args = std::make_shared(allArgs, sender.getIdentifier()); + auto args = std::make_shared>(allArgs, sender.getIdentifier()); WebRtcPluginManager::Instance().negotiateSdp(session, type, *args, [invoker, val, offer, headerOut](const WebRtcInterface &exchanger) mutable { auto &handler = const_cast(exchanger); try { @@ -2130,7 +2177,7 @@ void installWebApi() { auto &session = static_cast(sender); auto location = std::string(session.overSsl() ? "https://" : "http://") + allArgs["host"] + delete_webrtc_url; - auto args = std::make_shared(allArgs, sender.getIdentifier()); + auto args = std::make_shared>(allArgs, sender.getIdentifier()); WebRtcPluginManager::Instance().negotiateSdp(session, type, *args, [invoker, offer, headerOut, location](const WebRtcInterface &exchanger) mutable { auto &handler = const_cast(exchanger); try { @@ -2164,6 +2211,103 @@ void installWebApi() { obj->safeShutdown(SockException(Err_shutdown, "deleted by http api")); invoker(200, headerOut, ""); }); + + // 获取WebRTCProxyPlayer 连接信息 + api_regist("/index/api/getWebrtcProxyPlayerInfo", [](API_ARGS_MAP_ASYNC) { + CHECK_SECRET(); + CHECK_ARGS("key"); + + auto player_proxy = s_player_proxy.find(allArgs["key"]); + if (!player_proxy) { + throw ApiRetException("Stream proxy not found", API::NotFound); + } + + auto media_player = player_proxy->getDelegate(); + if (!media_player) { + throw ApiRetException("Media player not found", API::OtherFailed); + } + + auto webrtc_player_imp = std::dynamic_pointer_cast(media_player); + if (!webrtc_player_imp) { + throw ApiRetException("Stream proxy is not WebRTC type", API::OtherFailed); + } + + auto webrtc_transport = webrtc_player_imp->getWebRtcTransport(); + if (!webrtc_transport) { + throw ApiRetException("WebRTC transport not available", API::OtherFailed); + } + + std::string stream_key = allArgs["key"]; + webrtc_transport->getTransportInfo([val, headerOut, invoker, stream_key](Json::Value transport_info) mutable { + transport_info["stream_key"] = stream_key; + + if (transport_info.isMember("error")) { + Json::Value error_val; + error_val["code"] = API::OtherFailed; + error_val["msg"] = transport_info["error"].asString(); + invoker(200, headerOut, error_val.toStyledString()); + return; + } + + // 成功返回结果 + Json::Value success_val; + success_val["code"] = API::Success; + success_val["msg"] = "success"; + success_val["data"] = transport_info; + invoker(200, headerOut, success_val.toStyledString()); + }); + }); + + api_regist("/index/api/addWebrtcRoomKeeper",[](API_ARGS_MAP_ASYNC){ + CHECK_SECRET(); + CHECK_ARGS("server_host", "server_port", "room_id", "ssl"); + //server_host: 信令服务器host + //server_post: 信令服务器host + //room_id: 注册的id,信令服务器会对该id进行唯一性检查 + addWebrtcRoomKeeper(allArgs["server_host"], allArgs["server_port"], allArgs["room_id"], allArgs["ssl"], + [val, headerOut, invoker](const SockException &ex, const string &key) mutable { + if (ex) { + val["code"] = API::OtherFailed; + val["msg"] = ex.what(); + } else { + val["msg"] = "success"; + val["data"]["room_key"] = key; + } + invoker(200, headerOut, val.toStyledString()); + }); + }); + + api_regist("/index/api/delWebrtcRoomKeeper",[](API_ARGS_MAP_ASYNC){ + CHECK_SECRET(); + CHECK_ARGS("room_key"); + + delWebrtcRoomKeeper(allArgs["room_key"], + [val, headerOut, invoker](const SockException &ex) mutable { + if (ex) { + val["code"] = API::OtherFailed; + val["msg"] = ex.what(); + } + invoker(200, headerOut, val.toStyledString()); + }); + }); + + api_regist("/index/api/listWebrtcRoomKeepers", [](API_ARGS_MAP) { + CHECK_SECRET(); + listWebrtcRoomKeepers([&val](const std::string& key, const WebRtcSignalingPeer::Ptr& p) { + Json::Value item = ToJson(p); + item["room_key"] = key; + val["data"].append(item); + }); + }); + + api_regist("/index/api/listWebrtcRooms", [](API_ARGS_MAP) { + CHECK_SECRET(); + listWebrtcRooms([&val](const std::string& key, const WebRtcSignalingSession::Ptr& p) { + Json::Value item = ToJson(p); + item["room_id"] = key; + val["data"].append(item); + }); + }); #endif #if defined(ENABLE_VERSION) @@ -2201,6 +2345,20 @@ void installWebApi() { // sample_ms设置为0,从配置文件加载;file_repeat可以指定,如果配置文件也指定循环解复用,那么强制开启 [AUTO-TRANSLATED:23e826b4] // sample_ms is set to 0, loaded from the configuration file; file_repeat can be specified, if the configuration file also specifies loop demultiplexing, then force it to be enabled reader->startReadMP4(0, true, allArgs["file_repeat"]); + auto seek_ms = allArgs["seek_ms"].as(); + auto speed = allArgs["speed"].as(); + if (seek_ms || speed) { + auto p = static_pointer_cast(reader); + p->getOwnerPoller(MediaSource::NullMediaSource())->async([seek_ms, speed, p]() { + if (seek_ms) { + p->seekTo(MediaSource::NullMediaSource(), seek_ms); + } + if (speed && speed != 1.0) { + p->speed(MediaSource::NullMediaSource(), speed); + } + }); + } + val["data"]["duration_ms"] = (Json::UInt64)reader->getDemuxer()->getDurationMS(); }); #endif @@ -2249,14 +2407,120 @@ void installWebApi() { } }; - bool flag = NOTICE_EMIT(BroadcastHttpAccessArgs, Broadcast::kBroadcastHttpAccess, allArgs.parser, file_path, false, file_invoker, sender); - if (!flag) { - // 文件下载鉴权事件无人监听,不允许下载 [AUTO-TRANSLATED:5e02f0ce] - // No one is listening to the file download authentication event, download is not allowed - invoker(401, StrCaseMap {}, "None http access event listener"); + try { + CHECK_SECRET(); + // 校验secret成功,文件下载鉴权成功 + file_invoker("", "", 0); + } catch (...) { + bool flag = NOTICE_EMIT(BroadcastHttpAccessArgs, Broadcast::kBroadcastHttpAccess, allArgs.parser, allArgs.parser.url(), file_path, false, file_invoker, sender); + if (!flag) { + // 文件下载鉴权事件无人监听,不允许下载 [AUTO-TRANSLATED:5e02f0ce] + // No one is listening to the file download authentication event, download is not allowed + invoker(401, StrCaseMap {}, "None http access event listener"); + } } }); + api_regist("/index/api/searchOnvifDevice",[](API_ARGS_MAP_ASYNC){ + CHECK_SECRET(); + CHECK_ARGS("timeout_ms"); + + string subnet_prefix = allArgs["subnet_prefix"]; + + auto result = std::make_shared(std::move(val)); + auto complete_token = std::make_shared(nullptr, [result, headerOut, invoker]() { invoker(200, headerOut, result->toStyledString()); }); + auto lam_search = [complete_token, result](const std::map &device_info, const std::string &onvif_url) { + Value obj; + obj["onvif_url"] = onvif_url; + for (auto &pr : device_info) { + obj[pr.first] = pr.second; + } + (*result)["data"].append(std::move(obj)); + //继续等待扫描 + return true; + }; + OnvifSearcher::Instance().sendSearchBroadcast(std::move(subnet_prefix), std::move(lam_search), allArgs["timeout_ms"]); + }); + + api_regist("/index/api/getStreamUrl", [](API_ARGS_MAP_ASYNC) { + CHECK_SECRET(); + CHECK_ARGS("onvif_url"); + + SoapUtil::asyncGetStreamUri(allArgs["onvif_url"],[val, headerOut, allArgs, invoker] + (const SoapErr &err, const SoapUtil::GetStreamUriRetryInvoker &retry_invoker, + int retry_count, const std::string &url) mutable { + if (err && retry_count == 0 && !allArgs["user_name"].empty() /* && + (err.httpCode() == 400 || err.httpCode() == 401)*/) { + //第一次失败,且提供了用户密码,且确定是鉴权失败 + retry_invoker(allArgs["user_name"], allArgs["passwd"]); + return; + } + val["code"] = err ? API::OtherFailed : API::Success; + if (err) { + val["http_code"] = err.httpCode(); + val["msg"] = (string) err; + } else { + val["url"] = url; + } + invoker(200, headerOut, val.toStyledString()); + }); + }); + + api_regist("/index/api/login", [](API_ARGS_MAP) { + auto logined_cookie = HttpCookieManager::Instance().getCookie(kLoginedCookieName, allArgs.getParser().getHeader()); + + CHECK_ARGS("digest"); + GET_CONFIG(std::string, api_secret, API::kSecret); + + auto unlogin_cookie = HttpCookieManager::Instance().getCookie(kUnLoginCookieName, allArgs.getParser().getHeader()); + // MD5("zlmediakit:"+${secret}+":" +${cookie}) + auto digest_ok = unlogin_cookie ? MD5("zlmediakit:" + api_secret + ":" + unlogin_cookie->getCookie()).hexdigest() : ""; + if (!unlogin_cookie || digest_ok != allArgs["digest"]) { + if (!unlogin_cookie) { + unlogin_cookie = HttpCookieManager::Instance().addCookie(kUnLoginCookieName, "", kUnLoginCookieLifeSeconds); + headerOut["Set-Cookie"] = unlogin_cookie->getCookie(kLoginCookiePath); + } + val["cookie"] = unlogin_cookie->getCookie(); + if (logined_cookie) { + // secret校验失败,注销登录 + logined_cookie->setExpired(); + HttpCookieManager::Instance().delCookie(logined_cookie); + headerOut.emplace_force("Set-Cookie", logined_cookie->getCookie(kLoginCookiePath)); + } + throw AuthException("Digest does not match, incorrect secret?", headerOut, val); + } + if (!logined_cookie) { + // 未登陆状态,设置登录成功, cookie保持24小时 + logined_cookie = HttpCookieManager::Instance().addCookie(kLoginedCookieName, "", kLoginedCookieLifeSeconds); + headerOut["Set-Cookie"] = logined_cookie->getCookie(kLoginCookiePath); + } + + // 删除未登录状态的cookie + unlogin_cookie->setExpired(); + HttpCookieManager::Instance().delCookie(unlogin_cookie); + headerOut.emplace_force("Set-Cookie", unlogin_cookie->getCookie(kLoginCookiePath)); + + val["code"] = API::Success; + }); + + api_regist("/index/api/logout", [](API_ARGS_MAP) { + auto logined_cookie = HttpCookieManager::Instance().getCookie(kLoginedCookieName, allArgs.getParser().getHeader()); + if (logined_cookie) { + // 已经登录成功, 删除cookie + logined_cookie->setExpired(); + HttpCookieManager::Instance().delCookie(logined_cookie); + headerOut["Set-Cookie"] = logined_cookie->getCookie(kLoginCookiePath); + } else { + val["msg"] = "You are not logined"; + } + auto unlogin_cookie = HttpCookieManager::Instance().getCookie(kUnLoginCookieName, allArgs.getParser().getHeader()); + if (!unlogin_cookie) { + unlogin_cookie = HttpCookieManager::Instance().addCookie(kUnLoginCookieName, "", kUnLoginCookieLifeSeconds); + headerOut["Set-Cookie"] = unlogin_cookie->getCookie(kLoginCookiePath); + } + val["cookie"] = unlogin_cookie->getCookie(); + }); + #if defined(ENABLE_VIDEOSTACK) && defined(ENABLE_X264) && defined(ENABLE_FFMPEG) VideoStackManager::Instance().loadBgImg("novideo.yuv"); NoticeCenter::Instance().addListener(nullptr, Broadcast::kBroadcastStreamNoneReader, [](BroadcastStreamNoneReaderArgs) { @@ -2301,6 +2565,89 @@ void installWebApi() { invoker(200, headerOut, val.toStyledString()); }); #endif + + // 设置流播放速度 + // Set stream playback speed + api_regist("/index/api/setStreamSpeed", [](API_ARGS_JSON_ASYNC) { + CHECK_SECRET(); + CHECK_ARGS("vhost", "app", "stream", "speed"); + + std::string vhost = allArgs["vhost"]; + std::string app = allArgs["app"]; + std::string stream = allArgs["stream"]; + float speed = allArgs["speed"].as(); + + auto tuple = MediaTuple { vhost, app, stream, "" }; + std::string key = tuple.shortUrl(); + + auto player_proxy = s_player_proxy.find(key); + if (!player_proxy) { + throw ApiRetException("can not find the stream proxy", API::NotFound); + } + + player_proxy->getPoller()->async([=]() mutable { + player_proxy->MediaPlayer::speed(speed); + val["result"] = 0; + val["msg"] = "success"; + val["code"] = API::Success; + invoker(200, headerOut, val.toStyledString()); + }); + }); + + // 暂停/恢复流播放 + // Pause/Resume stream playback + api_regist("/index/api/pauseStream", [](API_ARGS_JSON_ASYNC) { + CHECK_SECRET(); + CHECK_ARGS("vhost", "app", "stream"); + + std::string vhost = allArgs["vhost"]; + std::string app = allArgs["app"]; + std::string stream = allArgs["stream"]; + + auto tuple = MediaTuple { vhost, app, stream, "" }; + std::string key = tuple.shortUrl(); + + auto player_proxy = s_player_proxy.find(key); + if (!player_proxy) { + throw ApiRetException("can not find the stream proxy", API::NotFound); + } + + player_proxy->getPoller()->async([=]() mutable { + player_proxy->MediaPlayer::pause(true); + val["result"] = 0; + val["msg"] = "success"; + val["code"] = API::Success; + invoker(200, headerOut, val.toStyledString()); + }); + }); + + // 跳转到指定位置 + // Seek to specified position + api_regist("/index/api/seekStream", [](API_ARGS_JSON_ASYNC) { + CHECK_SECRET(); + CHECK_ARGS("vhost", "app", "stream"); + + std::string vhost = allArgs["vhost"]; + std::string app = allArgs["app"]; + std::string stream = allArgs["stream"]; + uint32_t pos = allArgs["position"].as(); + + auto tuple = MediaTuple { vhost, app, stream, "" }; + std::string key = tuple.shortUrl(); + + auto player_proxy = s_player_proxy.find(key); + if (!player_proxy) { + throw ApiRetException("can not find the stream proxy", API::NotFound); + } + + player_proxy->getPoller()->async([=]() mutable { + player_proxy->MediaPlayer::seekTo(pos); + val["result"] = 0; + val["msg"] = "success"; + val["code"] = API::Success; + invoker(200, headerOut, val.toStyledString()); + }); + }); } void unInstallWebApi(){ diff --git a/server/WebApi.h b/server/WebApi.h index cc0e95fd..6c405f72 100755 --- a/server/WebApi.h +++ b/server/WebApi.h @@ -1,225 +1,383 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_WEBAPI_H -#define ZLMEDIAKIT_WEBAPI_H - -#include -#include -#include "json/json.h" -#include "Common/Parser.h" -#include "Network/Socket.h" -#include "Http/HttpSession.h" -#include "Common/MultiMediaSourceMuxer.h" - -// 配置文件路径 [AUTO-TRANSLATED:8a373c2f] -// Configuration file path -extern std::string g_ini_file; - -namespace mediakit { -// //////////RTSP服务器配置/////////// [AUTO-TRANSLATED:950e1981] -// //////////RTSP server configuration/////////// -namespace Rtsp { -extern const std::string kPort; -} //namespace Rtsp - -// //////////RTMP服务器配置/////////// [AUTO-TRANSLATED:8de6f41f] -// //////////RTMP server configuration/////////// -namespace Rtmp { -extern const std::string kPort; -} //namespace RTMP -} // namespace mediakit - -namespace API { -typedef enum { - NotFound = -500,//未找到 - Exception = -400,//代码抛异常 - InvalidArgs = -300,//参数不合法 - SqlFailed = -200,//sql执行失败 - AuthFailed = -100,//鉴权失败 - OtherFailed = -1,//业务代码执行失败, - Success = 0//执行成功 -} ApiErr; - -extern const std::string kSecret; -}//namespace API - -class ApiRetException: public std::runtime_error { -public: - ApiRetException(const char *str = "success" ,int code = API::Success):runtime_error(str){ - _code = code; - } - int code(){ return _code; } -private: - int _code; -}; - -class AuthException : public ApiRetException { -public: - AuthException(const char *str):ApiRetException(str,API::AuthFailed){} -}; - -class InvalidArgsException: public ApiRetException { -public: - InvalidArgsException(const char *str):ApiRetException(str,API::InvalidArgs){} -}; - -class SuccessException: public ApiRetException { -public: - SuccessException():ApiRetException("success",API::Success){} -}; - -using ApiArgsType = std::map; - -template -std::string getValue(Args &args, const First &first) { - return args[first]; -} - -template -std::string getValue(Json::Value &args, const First &first) { - return args[first].asString(); -} - -template -std::string getValue(std::string &args, const First &first) { - return ""; -} - -template -std::string getValue(const mediakit::Parser &parser, const First &first) { - auto ret = parser.getUrlArgs()[first]; - if (!ret.empty()) { - return ret; - } - return parser.getHeader()[first]; -} - -template -std::string getValue(mediakit::Parser &parser, const First &first) { - return getValue((const mediakit::Parser &) parser, first); -} - -template -std::string getValue(const mediakit::Parser &parser, Args &args, const First &first) { - auto ret = getValue(args, first); - if (!ret.empty()) { - return ret; - } - return getValue(parser, first); -} - -template -class HttpAllArgs { - mediakit::Parser* _parser = nullptr; - Args* _args = nullptr; -public: - const mediakit::Parser& parser; - Args& args; - - HttpAllArgs(const mediakit::Parser &p, Args &a): parser(p), args(a) {} - - HttpAllArgs(const HttpAllArgs &that): _parser(new mediakit::Parser(that.parser)), - _args(new Args(that.args)), - parser(*_parser), args(*_args) {} - ~HttpAllArgs() { - if (_parser) { - delete _parser; - } - if (_args) { - delete _args; - } - } - - template - toolkit::variant operator[](const Key &key) const { - return (toolkit::variant)getValue(parser, args, key); - } -}; - -using ArgsMap = HttpAllArgs; -using ArgsJson = HttpAllArgs; -using ArgsString = HttpAllArgs; - -#define API_ARGS_MAP toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const ArgsMap &allArgs, Json::Value &val -#define API_ARGS_MAP_ASYNC API_ARGS_MAP, const mediakit::HttpSession::HttpResponseInvoker &invoker -#define API_ARGS_JSON toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const ArgsJson &allArgs, Json::Value &val -#define API_ARGS_JSON_ASYNC API_ARGS_JSON, const mediakit::HttpSession::HttpResponseInvoker &invoker -#define API_ARGS_STRING toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const ArgsString &allArgs, Json::Value &val -#define API_ARGS_STRING_ASYNC API_ARGS_STRING, const mediakit::HttpSession::HttpResponseInvoker &invoker -#define API_ARGS_VALUE sender, headerOut, allArgs, val - -// 注册http请求参数是map类型的http api [AUTO-TRANSLATED:8a273897] -// Register http request parameters as map type http api -void api_regist(const std::string &api_path, const std::function &func); -// 注册http请求参数是map类型,但是可以异步回复的的http api [AUTO-TRANSLATED:9da5d5f5] -// Register http request parameters as map type, but can be replied asynchronously http api -void api_regist(const std::string &api_path, const std::function &func); - -// 注册http请求参数是Json::Value类型的http api(可以支持多级嵌套的json参数对象) [AUTO-TRANSLATED:c4794456] -// Register http request parameters as Json::Value type http api (can support multi-level nested json parameter objects) -void api_regist(const std::string &api_path, const std::function &func); -// 注册http请求参数是Json::Value类型,但是可以异步回复的的http api [AUTO-TRANSLATED:742e57fd] -// Register http request parameters as Json::Value type, but can be replied asynchronously http api -void api_regist(const std::string &api_path, const std::function &func); - -// 注册http请求参数是http原始请求信息的http api [AUTO-TRANSLATED:72d3fe93] -// Register http request parameters as http original request information http api -void api_regist(const std::string &api_path, const std::function &func); -// 注册http请求参数是http原始请求信息的异步回复的http api [AUTO-TRANSLATED:49feefa8] -// Register http request parameters as http original request information asynchronous reply http api -void api_regist(const std::string &api_path, const std::function &func); - -template -bool checkArgs(Args &args, const First &first) { - return !args[first].empty(); -} - -template -bool checkArgs(Args &args, const First &first, const KeyTypes &...keys) { - return checkArgs(args, first) && checkArgs(args, keys...); -} - -// 检查http url中或body中或http header参数是否为空的宏 [AUTO-TRANSLATED:9de001a4] -// Check whether the http url, body or http header parameters are empty -#define CHECK_ARGS(...) \ - if(!checkArgs(allArgs,##__VA_ARGS__)){ \ - throw InvalidArgsException("Required parameter missed: " #__VA_ARGS__); \ - } - -// 检查http参数中是否附带secret密钥的宏,127.0.0.1的ip不检查密钥 [AUTO-TRANSLATED:7546956c] -// Check whether the http parameters contain the secret key, the ip of 127.0.0.1 does not check the key -// 同时检测是否在ip白名单内 [AUTO-TRANSLATED:d12f963d] -// Check whether it is in the ip whitelist at the same time -#define CHECK_SECRET() \ - do { \ - auto ip = sender.get_peer_ip(); \ - if (!HttpFileManager::isIPAllowed(ip)) { \ - throw AuthException("Your ip is not allowed to access the service."); \ - } \ - CHECK_ARGS("secret"); \ - if (api_secret != allArgs["secret"]) { \ - throw AuthException("Incorrect secret"); \ - } \ - } while(false); - -void installWebApi(); -void unInstallWebApi(); - -#if defined(ENABLE_RTPPROXY) -uint16_t openRtpServer(uint16_t local_port, const mediakit::MediaTuple &tuple, int tcp_mode, const std::string &local_ip, bool re_use_port, uint32_t ssrc, int only_track, bool multiplex=false); -#endif - -Json::Value makeMediaSourceJson(mediakit::MediaSource &media); -void getStatisticJson(const std::function &cb); -void addStreamProxy(const mediakit::MediaTuple &tuple, const std::string &url, int retry_count, - const mediakit::ProtocolOption &option, int rtp_type, float timeout_sec, const toolkit::mINI &args, - const std::function &cb); -#endif //ZLMEDIAKIT_WEBAPI_H +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBAPI_H +#define ZLMEDIAKIT_WEBAPI_H + +#include +#include +#include "json/json.h" +#include "Common/Parser.h" +#include "Network/Socket.h" +#include "Http/HttpSession.h" +#include "Common/MultiMediaSourceMuxer.h" + +#if defined(ENABLE_WEBRTC) +#include "webrtc/WebRtcTransport.h" +#endif + +#include "Http/HttpCookieManager.h" + +// 配置文件路径 [AUTO-TRANSLATED:8a373c2f] +// Configuration file path +extern std::string g_ini_file; + +namespace mediakit { +// //////////RTSP服务器配置/////////// [AUTO-TRANSLATED:950e1981] +// //////////RTSP server configuration/////////// +namespace Rtsp { +extern const std::string kPort; +} //namespace Rtsp + +// //////////RTMP服务器配置/////////// [AUTO-TRANSLATED:8de6f41f] +// //////////RTMP server configuration/////////// +namespace Rtmp { +extern const std::string kPort; +} //namespace RTMP +} // namespace mediakit + +namespace API { +typedef enum { + NotFound = -500,//未找到 + Exception = -400,//代码抛异常 + InvalidArgs = -300,//参数不合法 + SqlFailed = -200,//sql执行失败 + AuthFailed = -100,//鉴权失败 + OtherFailed = -1,//业务代码执行失败, + Success = 0//执行成功 +} ApiErr; + +extern const std::string kSecret; +extern const std::string kApiDebug; +} // namespace API + +class ApiRetException : public std::runtime_error { +public: + ApiRetException(const char *str = "success", int code = API::Success, mediakit::StrCaseMap headers = {}, Json::Value body = {}) + : runtime_error(str) { + _code = code; + _headers = std::move(headers); + _body = std::move(body); + } + int code() { return _code; } + + mediakit::StrCaseMap &getHeaders() { return _headers; } + + Json::Value &getBody() { return _body; } + +private: + int _code; + mediakit::StrCaseMap _headers; + Json::Value _body; +}; + +class AuthException : public ApiRetException { +public: + AuthException(const char *str, mediakit::StrCaseMap headers = {}, Json::Value body = {}) + : ApiRetException(str, API::AuthFailed, std::move(headers), std::move(body)) {} +}; + +class InvalidArgsException : public ApiRetException { +public: + InvalidArgsException(const char *str, mediakit::StrCaseMap headers = {}, Json::Value body = {}) + : ApiRetException(str, API::InvalidArgs, std::move(headers), std::move(body)) {} +}; + +class SuccessException : public ApiRetException { +public: + SuccessException(mediakit::StrCaseMap headers = {}, Json::Value body = {}) + : ApiRetException("success", API::Success, std::move(headers), std::move(body)) {} +}; + +using ApiArgsType = std::map; + +template +std::string getValue(Args &args, const Key &key) { + auto it = args.find(key); + if (it == args.end()) { + return ""; + } + return it->second; +} + +template +std::string getValue(Json::Value &args, const Key &key) { + auto value = args.find(key); + if (value == nullptr) { + return ""; + } + return value->asString(); +} + +template +std::string getValue(std::string &args, const Key &key) { + return ""; +} + +template +std::string getValue(const mediakit::Parser &parser, const Key &key) { + auto ret = getValue(parser.getUrlArgs(), key); + if (!ret.empty()) { + return ret; + } + return getValue(parser.getHeader(), key); +} + +template +std::string getValue(mediakit::Parser &parser, const Key &key) { + return getValue((const mediakit::Parser &) parser, key); +} + +template +std::string getValue(const mediakit::Parser &parser, Args &args, const Key &key) { + auto ret = getValue(args, key); + if (!ret.empty()) { + return ret; + } + return getValue(parser, key); +} + +template +class HttpAllArgs { + mediakit::Parser* _parser = nullptr; + Args* _args = nullptr; +public: + const mediakit::Parser& parser; + Args& args; + + HttpAllArgs(const mediakit::Parser &p, Args &a): parser(p), args(a) {} + + HttpAllArgs(const HttpAllArgs &that): _parser(new mediakit::Parser(that.parser)), + _args(new Args(that.args)), + parser(*_parser), args(*_args) {} + ~HttpAllArgs() { + if (_parser) { + delete _parser; + } + if (_args) { + delete _args; + } + } + + template + toolkit::variant operator[](const Key &key) const { + return (toolkit::variant)getValue(parser, args, key); + } + + const Args& getArgs() const { + return args; + } + + const mediakit::Parser &getParser() const { + return parser; + } +}; + +using ArgsMap = HttpAllArgs; +using ArgsJson = HttpAllArgs; +using ArgsString = HttpAllArgs; + +#define API_ARGS_MAP toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const ArgsMap &allArgs, Json::Value &val +#define API_ARGS_MAP_ASYNC API_ARGS_MAP, const mediakit::HttpSession::HttpResponseInvoker &invoker +#define API_ARGS_JSON toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const ArgsJson &allArgs, Json::Value &val +#define API_ARGS_JSON_ASYNC API_ARGS_JSON, const mediakit::HttpSession::HttpResponseInvoker &invoker +#define API_ARGS_STRING toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const ArgsString &allArgs, Json::Value &val +#define API_ARGS_STRING_ASYNC API_ARGS_STRING, const mediakit::HttpSession::HttpResponseInvoker &invoker +#define API_ARGS_VALUE sender, headerOut, allArgs, val + +// 注册http请求参数是map类型的http api [AUTO-TRANSLATED:8a273897] +// Register http request parameters as map type http api +void api_regist(const std::string &api_path, const std::function &func); +// 注册http请求参数是map类型,但是可以异步回复的的http api [AUTO-TRANSLATED:9da5d5f5] +// Register http request parameters as map type, but can be replied asynchronously http api +void api_regist(const std::string &api_path, const std::function &func); + +// 注册http请求参数是Json::Value类型的http api(可以支持多级嵌套的json参数对象) [AUTO-TRANSLATED:c4794456] +// Register http request parameters as Json::Value type http api (can support multi-level nested json parameter objects) +void api_regist(const std::string &api_path, const std::function &func); +// 注册http请求参数是Json::Value类型,但是可以异步回复的的http api [AUTO-TRANSLATED:742e57fd] +// Register http request parameters as Json::Value type, but can be replied asynchronously http api +void api_regist(const std::string &api_path, const std::function &func); + +// 注册http请求参数是http原始请求信息的http api [AUTO-TRANSLATED:72d3fe93] +// Register http request parameters as http original request information http api +void api_regist(const std::string &api_path, const std::function &func); +// 注册http请求参数是http原始请求信息的异步回复的http api [AUTO-TRANSLATED:49feefa8] +// Register http request parameters as http original request information asynchronous reply http api +void api_regist(const std::string &api_path, const std::function &func); + +template +bool checkArgs(Args &args, const Key &key) { + return !args[key].empty(); +} + +template +bool checkArgs(Args &args, const Key &key, const KeyTypes &...keys) { + return checkArgs(args, key) && checkArgs(args, keys...); +} + +// 检查http url中或body中或http header参数是否为空的宏 [AUTO-TRANSLATED:9de001a4] +// Check whether the http url, body or http header parameters are empty +#define CHECK_ARGS(...) \ + if(!checkArgs(allArgs,##__VA_ARGS__)){ \ + throw InvalidArgsException("Required parameter missed: " #__VA_ARGS__); \ + } + +// 检查http参数中是否附带secret密钥的宏,127.0.0.1的ip不检查密钥 [AUTO-TRANSLATED:7546956c] +// Check whether the http parameters contain the secret key, the ip of 127.0.0.1 does not check the key +// 同时检测是否在ip白名单内 [AUTO-TRANSLATED:d12f963d] +// Check whether it is in the ip whitelist at the same time +template +void check_secret(toolkit::SockInfo &sender, mediakit::HttpSession::KeyValue &headerOut, const HttpAllArgs &allArgs, Json::Value &val); +#define CHECK_SECRET() check_secret(sender, headerOut, allArgs, val) + +void installWebApi(); +void unInstallWebApi(); + +#if defined(ENABLE_RTPPROXY) +uint16_t openRtpServer(uint16_t local_port, const mediakit::MediaTuple &tuple, int tcp_mode, const std::string &local_ip, bool re_use_port, uint32_t ssrc, int only_track, bool multiplex=false); +#endif + +Json::Value makeMediaSourceJson(mediakit::MediaSource &media); +ApiArgsType getAllArgs(const mediakit::Parser &parser); +void getStatisticJson(const std::function &cb); +void addStreamProxy(const mediakit::MediaTuple &tuple, const std::string &url, int retry_count, bool force, + const mediakit::ProtocolOption &option, float timeout_sec, const toolkit::mINI &args, + const std::function &cb); + +void updateStreamProxy(const mediakit::MediaTuple &tuple, const std::string &url, const toolkit::mINI &args); + +template +class ServiceController { +public: + using Pointer = std::shared_ptr; + std::unordered_map _map; + mutable std::recursive_mutex _mtx; + + void clear() { + decltype(_map) copy; + { + std::lock_guard lck(_mtx); + copy.swap(_map); + } + } + + size_t erase(const std::string &key) { + Pointer erase_ptr; + { + std::lock_guard lck(_mtx); + auto itr = _map.find(key); + if (itr != _map.end()) { + erase_ptr = std::move(itr->second); + _map.erase(itr); + return 1; + } + } + return 0; + } + + size_t size() { + std::lock_guard lck(_mtx); + return _map.size(); + } + + Pointer find(const std::string &key) const { + std::lock_guard lck(_mtx); + auto it = _map.find(key); + if (it == _map.end()) { + return nullptr; + } + return it->second; + } + + void for_each(const std::function &cb, const std::string &key = {}) { + std::lock_guard lck(_mtx); + if (key.empty()) { + auto it = _map.begin(); + while (it != _map.end()) { + cb(it->first, it->second); + ++it; + } + } else { + auto it = _map.find(key); + if (it == _map.end()) { + throw std::invalid_argument("key not found: " + key); + } + cb(key, it->second); + } + } + + template + Pointer make(const std::string &key, _Args&& ...__args) { + // assert(!find(key)); + + auto server = std::make_shared(std::forward<_Args>(__args)...); + std::lock_guard lck(_mtx); + auto it = _map.emplace(key, server); + assert(it.second); + return server; + } + + template + Pointer makeWithAction(const std::string &key, std::function action, _Args&& ...__args) { + // assert(!find(key)); + + auto server = std::make_shared(std::forward<_Args>(__args)...); + action(server); + std::lock_guard lck(_mtx); + auto it = _map.emplace(key, server); + assert(it.second); + return server; + } + + template + Pointer emplace(const std::string &key, _Args&& ...__args) { + // assert(!find(key)); + + auto server = std::static_pointer_cast(std::forward<_Args>(__args)...); + std::lock_guard lck(_mtx); + auto it = _map.emplace(key, server); + assert(it.second); + return server; + } +}; + +#if defined(ENABLE_WEBRTC) +template +class WebRtcArgsImp : public mediakit::WebRtcArgs { +public: + WebRtcArgsImp(const HttpAllArgs &args, std::string session_id) + : _args(args) + , _session_id(std::move(session_id)) {} + ~WebRtcArgsImp() override = default; + + toolkit::variant operator[](const std::string &key) const override { + if (key == "url") { + return getUrl(); + } + return _args[key]; + } + +private: + std::string getUrl() const { + auto &allArgs = _args; + CHECK_ARGS("app", "stream"); + + return StrPrinter << RTC_SCHEMA << "://" << (_args["Host"].empty() ? DEFAULT_VHOST : _args["Host"].data()) << "/" << _args["app"] << "/" + << _args["stream"] << "?" << _args.getParser().params() + "&session=" + _session_id; + } + +private: + HttpAllArgs _args; + std::string _session_id; +}; +#endif + +#endif //ZLMEDIAKIT_WEBAPI_H diff --git a/server/WebHook.cpp b/server/WebHook.cpp index 503e640e..5447bb6b 100755 --- a/server/WebHook.cpp +++ b/server/WebHook.cpp @@ -18,9 +18,14 @@ #include "Http/HttpRequester.h" #include "Network/Session.h" #include "Rtsp/RtspSession.h" +#include "Player/PlayerProxy.h" #include "WebHook.h" #include "WebApi.h" +#if defined(ENABLE_PYTHON) +#include "pyinvoker.h" +#endif + using namespace std; using namespace Json; using namespace toolkit; @@ -226,9 +231,14 @@ void do_http_hook(const string &url, const ArgsType &body, const function &urls, size_t index, size_t failed_cnt, const MediaInfo &args, const function &closePlayer) { @@ -311,7 +321,7 @@ static void pullStreamFromOrigin(const vector &urls, size_t index, size_ option.enable_hls = option.enable_hls || (args.schema == HLS_SCHEMA); option.enable_mp4 = false; - addStreamProxy(args, url, retry_count, option, Rtsp::RTP_TCP, timeout_sec, mINI{}, [=](const SockException &ex, const string &key) mutable { + addStreamProxy(args, url, retry_count, false, option, timeout_sec, mINI{}, [=](const SockException &ex, const string &key) mutable { if (!ex) { return; } @@ -334,6 +344,10 @@ static mINI jsonToMini(const Value &obj) { mINI ret; if (obj.isObject()) { for (auto it = obj.begin(); it != obj.end(); ++it) { + if (it->isNull()) { + // 忽略null,修复wvp传null覆盖Protocol配置的问题 + continue; + } try { auto str = (*it).asString(); ret[it.name()] = std::move(str); @@ -345,10 +359,29 @@ static mINI jsonToMini(const Value &obj) { return ret; } +ArgsType getRecordInfo(const RecordInfo &info) { + ArgsType body; + body["start_time"] = (Json::UInt64)info.start_time; + body["file_size"] = (Json::UInt64)info.file_size; + body["time_len"] = info.time_len; + body["file_path"] = info.file_path; + body["file_name"] = info.file_name; + body["folder"] = info.folder; + body["url"] = info.url; + dumpMediaTuple(info, body); + return body; +} + void installWebHook() { GET_CONFIG(bool, hook_enable, Hook::kEnable); NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastMediaPublish, [](BroadcastMediaPublishArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_publish(type, args, invoker, sender)) { + return; + } +#endif + GET_CONFIG(string, hook_publish, Hook::kOnPublish); if (!hook_enable || hook_publish.empty()) { invoker("", ProtocolOption()); @@ -378,6 +411,11 @@ void installWebHook() { }); NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastMediaPlayed, [](BroadcastMediaPlayedArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_play(args, invoker, sender)) { + return; + } +#endif GET_CONFIG(string, hook_play, Hook::kOnPlay); if (!hook_enable || hook_play.empty()) { invoker(""); @@ -393,6 +431,11 @@ void installWebHook() { }); NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastFlowReport, [](BroadcastFlowReportArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_flow_report(args, totalBytes, totalDuration, isPlayer, sender)) { + return; + } +#endif GET_CONFIG(string, hook_flowreport, Hook::kOnFlowReport); if (!hook_enable || hook_flowreport.empty()) { return; @@ -414,6 +457,11 @@ void installWebHook() { // 监听kBroadcastOnGetRtspRealm事件决定rtsp链接是否需要鉴权(传统的rtsp鉴权方案)才能访问 [AUTO-TRANSLATED:00dc9fa3] // Listen to the kBroadcastOnGetRtspRealm event to determine whether the rtsp link needs authentication (traditional rtsp authentication scheme) to access NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastOnGetRtspRealm, [](BroadcastOnGetRtspRealmArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_get_rtsp_realm(args, invoker, sender)) { + return; + } +#endif GET_CONFIG(string, hook_rtsp_realm, Hook::kOnRtspRealm); if (!hook_enable || hook_rtsp_realm.empty()) { // 无需认证 [AUTO-TRANSLATED:77728e07] @@ -441,6 +489,11 @@ void installWebHook() { // 监听kBroadcastOnRtspAuth事件返回正确的rtsp鉴权用户密码 [AUTO-TRANSLATED:bcf1754e] // Listen to the kBroadcastOnRtspAuth event to return the correct rtsp authentication username and password NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastOnRtspAuth, [](BroadcastOnRtspAuthArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_rtsp_auth(args, realm, user_name, must_no_encrypt, invoker, sender)) { + return; + } +#endif GET_CONFIG(string, hook_rtsp_auth, Hook::kOnRtspAuth); if (unAuthedRealm == realm || !hook_enable || hook_rtsp_auth.empty()) { // 认证失败 [AUTO-TRANSLATED:70cf56ff] @@ -471,10 +524,6 @@ void installWebHook() { // 监听rtsp、rtmp源注册或注销事件 [AUTO-TRANSLATED:6396afa8] // Listen to rtsp, rtmp source registration or deregistration events NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastMediaChanged, [](BroadcastMediaChangedArgs) { - GET_CONFIG(string, hook_stream_changed, Hook::kOnStreamChanged); - if (!hook_enable || hook_stream_changed.empty()) { - return; - } GET_CONFIG_FUNC(std::set, stream_changed_set, Hook::kStreamChangedSchemas, [](const std::string &str) { std::set ret; auto vec = split(str, "/"); @@ -491,6 +540,15 @@ void installWebHook() { // This protocol registration deregistration event is ignored return; } +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_media_changed(bRegist, sender)) { + return; + } +#endif + GET_CONFIG(string, hook_stream_changed, Hook::kOnStreamChanged); + if (!hook_enable || hook_stream_changed.empty()) { + return; + } ArgsType body; if (bRegist) { @@ -536,6 +594,12 @@ void installWebHook() { return; } +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_stream_not_found(args, sender, closePlayer)) { + return; + } +#endif + GET_CONFIG(string, hook_stream_not_found, Hook::kOnStreamNotFound); if (!hook_enable || hook_stream_not_found.empty()) { return; @@ -559,23 +623,15 @@ void installWebHook() { do_http_hook(hook_stream_not_found, body, res_cb); }); - static auto getRecordInfo = [](const RecordInfo &info) { - ArgsType body; - body["start_time"] = (Json::UInt64)info.start_time; - body["file_size"] = (Json::UInt64)info.file_size; - body["time_len"] = info.time_len; - body["file_path"] = info.file_path; - body["file_name"] = info.file_name; - body["folder"] = info.folder; - body["url"] = info.url; - dumpMediaTuple(info, body); - return body; - }; - #ifdef ENABLE_MP4 // 录制mp4文件成功后广播 [AUTO-TRANSLATED:479ec954] // Broadcast after recording the mp4 file successfully NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastRecordMP4, [](BroadcastRecordMP4Args) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_record_mp4(info)) { + return; + } +#endif GET_CONFIG(string, hook_record_mp4, Hook::kOnRecordMp4); if (!hook_enable || hook_record_mp4.empty()) { return; @@ -587,6 +643,11 @@ void installWebHook() { #endif // ENABLE_MP4 NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastRecordTs, [](BroadcastRecordTsArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_record_ts(info)) { + return; + } +#endif GET_CONFIG(string, hook_record_ts, Hook::kOnRecordTs); if (!hook_enable || hook_record_ts.empty()) { return; @@ -615,13 +676,31 @@ void installWebHook() { }); NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastStreamNoneReader, [](BroadcastStreamNoneReaderArgs) { + auto auto_close = false; + auto muxer = sender.getMuxer(); + if (muxer && muxer->getOption().auto_close) { + auto_close = true; + } + if (!origin_urls.empty() && sender.getOriginType() == MediaOriginType::pull) { // 边沿站无人观看时如果是拉流的则立即停止溯源 [AUTO-TRANSLATED:a1429c77] // If no one is watching at the edge station, stop tracing immediately if it is pulling - sender.close(false); - WarnL << "无人观看主动关闭流:" << sender.getOriginUrl(); + if (!auto_close) { + auto ptr = sender.shared_from_this(); + sender.getOwnerPoller()->async([ptr]() { + ptr->close(false); + }); + WarnL << "Auto close stream when none reader: " << sender.getOriginUrl(); + } return; } + +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_stream_none_reader(sender)) { + return; + } +#endif + GET_CONFIG(string, hook_stream_none_reader, Hook::kOnStreamNoneReader); if (!hook_enable || hook_stream_none_reader.empty()) { return; @@ -633,18 +712,27 @@ void installWebHook() { weak_ptr weakSrc = sender.shared_from_this(); // 执行hook [AUTO-TRANSLATED:1df68201] // Execute hook - do_http_hook(hook_stream_none_reader, body, [weakSrc](const Value &obj, const string &err) { + do_http_hook(hook_stream_none_reader, body, [weakSrc, auto_close](const Value &obj, const string &err) { + if (auto_close) { + // 在上层已经关闭了 + return; + } bool flag = obj["close"].asBool(); auto strongSrc = weakSrc.lock(); if (!flag || !err.empty() || !strongSrc) { return; } - strongSrc->close(false); + strongSrc->getOwnerPoller()->async([strongSrc]() { strongSrc->close(false); }); WarnL << "无人观看主动关闭流:" << strongSrc->getOriginUrl(); }); }); NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastSendRtpStopped, [](BroadcastSendRtpStoppedArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_send_rtp_stopped(sender, ssrc, ex)) { + return; + } +#endif GET_CONFIG(string, hook_send_rtp_stopped, Hook::kOnSendRtpStopped); if (!hook_enable || hook_send_rtp_stopped.empty()) { return; @@ -694,6 +782,11 @@ void installWebHook() { // 追踪用户的目的是为了缓存上次鉴权结果,减少鉴权次数,提高性能 [AUTO-TRANSLATED:22827145] // The purpose of tracking users is to cache the last authentication result, reduce the number of authentication times, and improve performance NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastHttpAccess, [](BroadcastHttpAccessArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_http_access(parser, path, file_path, is_dir, invoker, sender)) { + return; + } +#endif GET_CONFIG(string, hook_http_access, Hook::kOnHttpAccess); if (!hook_enable || hook_http_access.empty()) { // 未开启http文件访问鉴权,那么允许访问,但是每次访问都要鉴权; [AUTO-TRANSLATED:deb3a0ae] @@ -713,6 +806,7 @@ void installWebHook() { body["port"] = sender.get_peer_port(); body["id"] = sender.getIdentifier(); body["path"] = path; + body["file_path"] = file_path; body["is_dir"] = is_dir; body["params"] = parser.params(); for (auto &pr : parser.getHeader()) { @@ -738,6 +832,11 @@ void installWebHook() { }); NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastRtpServerTimeout, [](BroadcastRtpServerTimeoutArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_rtp_server_timeout(local_port, tuple, tcp_mode, re_use_port, ssrc)) { + return; + } +#endif GET_CONFIG(string, rtp_server_timeout, Hook::kOnRtpServerTimeout); if (!hook_enable || rtp_server_timeout.empty()) { return; @@ -754,6 +853,14 @@ void installWebHook() { do_http_hook(rtp_server_timeout, body); }); + NoticeCenter::Instance().addListener(&web_hook_tag, Broadcast::kBroadcastPlayerProxyFailed, [](BroadcastPlayerProxyFailedArgs) { +#if defined(ENABLE_PYTHON) + if (PythonInvoker::Instance().on_player_proxy_failed(sender, ex)) { + return; + } +#endif + }); + // 汇报服务器重新启动 [AUTO-TRANSLATED:bd7d83df] // Report server restart reportServerStarted(); diff --git a/server/main.cpp b/server/main.cpp index 051f2bfe..45349cdc 100644 --- a/server/main.cpp +++ b/server/main.cpp @@ -30,6 +30,8 @@ #if defined(ENABLE_WEBRTC) #include "../webrtc/WebRtcTransport.h" #include "../webrtc/WebRtcSession.h" +#include "../webrtc/WebRtcSignalingSession.h" +#include "../webrtc/IceSession.hpp" #endif #if defined(ENABLE_SRT) @@ -41,9 +43,11 @@ #include "ZLMVersion.h" #endif -#if !defined(_WIN32) +#if defined(ENABLE_PYTHON) +#include "pyinvoker.h" +#endif + #include "System.h" -#endif//!defined(_WIN32) using namespace std; using namespace toolkit; @@ -59,7 +63,7 @@ const string kSSLPort = HTTP_FIELD"sslport"; onceToken token1([](){ mINI::Instance()[kPort] = 80; mINI::Instance()[kSSLPort] = 443; -},nullptr); +}); }//namespace Http // //////////SHELL配置/////////// [AUTO-TRANSLATED:f023ec45] @@ -69,7 +73,7 @@ namespace Shell { const string kPort = SHELL_FIELD"port"; onceToken token1([](){ mINI::Instance()[kPort] = 9000; -},nullptr); +}); } //namespace Shell // //////////RTSP服务器配置/////////// [AUTO-TRANSLATED:950e1981] @@ -81,7 +85,7 @@ const string kSSLPort = RTSP_FIELD"sslport"; onceToken token1([](){ mINI::Instance()[kPort] = 554; mINI::Instance()[kSSLPort] = 332; -},nullptr); +}); } //namespace Rtsp @@ -94,7 +98,7 @@ const string kSSLPort = RTMP_FIELD"sslport"; onceToken token1([](){ mINI::Instance()[kPort] = 1935; mINI::Instance()[kSSLPort] = 19350; -},nullptr); +}); } //namespace RTMP // //////////Rtp代理相关配置/////////// [AUTO-TRANSLATED:7b285587] @@ -104,9 +108,17 @@ namespace RtpProxy { const string kPort = RTP_PROXY_FIELD"port"; onceToken token1([](){ mINI::Instance()[kPort] = 10000; -},nullptr); +}); } //namespace RtpProxy +namespace Python { +#define Python_FIELD "python." +const string kPlugin = Python_FIELD"plugin"; +onceToken token1([](){ + mINI::Instance()[kPlugin] = ""; +}); +} //namespace Python + } // namespace mediakit @@ -259,10 +271,21 @@ int start_main(int argc,char *argv[]) { // Start daemon process System::startDaemon(kill_parent_if_failed); } +#endif //! defined(_WIN32) + + // 设置poller线程数和cpu亲和性,该函数必须在使用ZLToolKit网络相关对象之前调用才能生效 [AUTO-TRANSLATED:7f03a1e5] + // Set the number of poller threads and CPU affinity. This function must be called before using ZLToolKit network related objects to take effect. + // 如果需要调用getSnap和addFFmpegSource接口,可以关闭cpu亲和性 [AUTO-TRANSLATED:7629f7bc] + // If you need to call the getSnap and addFFmpegSource interfaces, you can turn off CPU affinity + + EventPollerPool::setPoolSize(threads); + WorkThreadPool::setPoolSize(threads); + EventPollerPool::enableCpuAffinity(affinity); + WorkThreadPool::enableCpuAffinity(affinity); + // 开启崩溃捕获等 [AUTO-TRANSLATED:9c7c759c] // Enable crash capture, etc. System::systemSetup(); -#endif//!defined(_WIN32) // 启动异步日志线程 [AUTO-TRANSLATED:c93cc6f4] // Start asynchronous log thread @@ -316,15 +339,6 @@ int start_main(int argc,char *argv[]) { uint16_t httpsPort = mINI::Instance()[Http::kSSLPort]; uint16_t rtpPort = mINI::Instance()[RtpProxy::kPort]; - // 设置poller线程数和cpu亲和性,该函数必须在使用ZLToolKit网络相关对象之前调用才能生效 [AUTO-TRANSLATED:7f03a1e5] - // Set the number of poller threads and CPU affinity. This function must be called before using ZLToolKit network related objects to take effect. - // 如果需要调用getSnap和addFFmpegSource接口,可以关闭cpu亲和性 [AUTO-TRANSLATED:7629f7bc] - // If you need to call the getSnap and addFFmpegSource interfaces, you can turn off CPU affinity - - EventPollerPool::setPoolSize(threads); - WorkThreadPool::setPoolSize(threads); - EventPollerPool::enableCpuAffinity(affinity); - // 简单的telnet服务器,可用于服务器调试,但是不能使用23端口,否则telnet上了莫名其妙的现象 [AUTO-TRANSLATED:f9324c6e] // Simple telnet server, can be used for server debugging, but cannot use port 23, otherwise telnet will have inexplicable phenomena // 测试方法:telnet 127.0.0.1 9000 [AUTO-TRANSLATED:de0ac883] @@ -369,8 +383,17 @@ int start_main(int argc,char *argv[]) { } return Socket::createSocket(new_poller, false); }); + + auto signaleSrv = std::make_shared(); + auto signalsSrv = std::make_shared(); + auto iceTcpSrv = std::make_shared(); + auto iceSrv = std::make_shared(); uint16_t rtcPort = mINI::Instance()[Rtc::kPort]; uint16_t rtcTcpPort = mINI::Instance()[Rtc::kTcpPort]; + uint16_t signalingPort = mINI::Instance()[Rtc::kSignalingPort]; + uint16_t signalSslPort = mINI::Instance()[Rtc::kSignalingSslPort]; + uint16_t icePort = mINI::Instance()[Rtc::kIcePort]; + uint16_t iceTcpPort = mINI::Instance()[Rtc::kIceTcpPort]; #endif//defined(ENABLE_WEBRTC) @@ -436,6 +459,12 @@ int start_main(int argc,char *argv[]) { if (rtcTcpPort) { rtcSrv_tcp->start(rtcTcpPort, listen_ip);} + //webrtc 信令服务器 + if (signalingPort) { signaleSrv->start(signalingPort);} + if (signalSslPort) { signalsSrv->start(signalSslPort);} + //STUN/TURN服务 + if (icePort) { iceSrv->start(icePort);} + if (iceTcpPort) { iceTcpSrv->start(iceTcpPort);} #endif//defined(ENABLE_WEBRTC) #if defined(ENABLE_SRT) @@ -478,12 +507,25 @@ int start_main(int argc,char *argv[]) { g_reload_certificates(); }); #endif + +#if defined(ENABLE_PYTHON) + // 初始化python解释器 + auto &ref = PythonInvoker::Instance(); + auto py_plugin = mINI::Instance()[Python::kPlugin]; + if (!py_plugin.empty()) { + ref.load(py_plugin); + } +#endif sem.wait(); } unInstallWebApi(); unInstallWebHook(); onProcessExited(); +#if defined(ENABLE_PYTHON) + PythonInvoker::release(); +#endif + // 休眠1秒再退出,防止资源释放顺序错误 [AUTO-TRANSLATED:1b11a74f] // sleep for 1 second before exiting, to prevent resource release order errors InfoL << "程序退出中,请等待..."; diff --git a/server/pyinvoker.cpp b/server/pyinvoker.cpp new file mode 100644 index 00000000..ccd0ca7f --- /dev/null +++ b/server/pyinvoker.cpp @@ -0,0 +1,794 @@ +#if defined(ENABLE_PYTHON) + +#include "pyinvoker.h" + +#include +#include +#include +#include +#include +#include "WebApi.h" +#include "WebHook.h" +#include "Util/util.h" +#include "Util/File.h" +#include "Common/Parser.h" +#include "Common/macros.h" +#include "Http/HttpSession.h" +#include "Poller/EventPoller.h" +#include "WebApi.h" + +using namespace toolkit; +using namespace mediakit; + +extern ArgsType make_json(const MediaInfo &args); +extern void fillSockInfo(Json::Value & val, SockInfo* info); +extern ArgsType getRecordInfo(const RecordInfo &info); +extern std::string g_ini_file; + +template +typename std::enable_if::value, py::capsule>::type to_python(const T &obj) { + static auto name_str = toolkit::demangle(typeid(T).name()); + auto p = new toolkit::Any(std::make_shared(obj)); + return py::capsule(p, name_str.data(), [](PyObject *capsule) { + auto p = reinterpret_cast(PyCapsule_GetPointer(capsule, name_str.data())); + delete p; + TraceL << "delete " << name_str << "(" << p << ")"; + }); +} + +template +typename std::enable_if::value, py::capsule>::type to_python(const T &obj) { + static auto name_str = toolkit::demangle(typeid(T).name()); + auto p = new toolkit::Any(std::shared_ptr(const_cast(&obj), [](T *) {})); + return py::capsule(p, name_str.data(), [](PyObject *capsule) { + auto p = reinterpret_cast(PyCapsule_GetPointer(capsule, name_str.data())); + delete p; + TraceL << "unref " << name_str << "(" << p << ")"; + }); +} + +static py::dict jsonToPython(const Json::Value &obj) { + py::dict ret; + if (obj.isObject()) { + for (auto it = obj.begin(); it != obj.end(); ++it) { + if (it->isNull()) { + // 忽略null,修复wvp传null覆盖Protocol配置的问题 + continue; + } + try { + auto str = (*it).asString(); + ret[it.name().data()] = std::move(str); + } catch (std::exception &) { + WarnL << "Json is not convertible to string, key: " << it.name() << ", value: " << (*it); + } + } + } + return ret; +} + +py::dict to_python(const MediaInfo &args) { + auto json = make_json(args); + return jsonToPython(json); +} + +py::dict to_python(const SockInfo &info) { + Json::Value json; + fillSockInfo(json, const_cast(&info)); + return jsonToPython(json); +} + +py::dict to_python(const RecordInfo &info) { + return jsonToPython(getRecordInfo(info)); +} + +template +std::shared_ptr to_python_ref(const T &t) { + return std::shared_ptr(const_cast(&t), py::nodelete()); +} + +template +T &to_native(const py::capsule &cap) { + static auto name_str = toolkit::demangle(typeid(T).name()); + if (std::string(cap.name()) != name_str) { + throw std::runtime_error("Invalid capsule name!"); + } + auto any = static_cast(cap.get_pointer()); + return any->get(); +} + +mINI to_native(const py::dict &opt) { + mINI ret; + for (auto &item : opt) { + // 转换为字符串(允许 int/float/bool 等) + ret.emplace(py::str(item.first).cast(), py::str(item.second).cast()); + } + return ret; +} + +void python_api_debug(const Parser &parser, const std::string &body) { + GET_CONFIG(bool, api_debug, API::kApiDebug); + if (!api_debug) { + return; + } + ssize_t size = body.size(); + LogContextCapture log(getLogger(), toolkit::LDebug, __FILE__, "python http api debug", __LINE__); + log << "\r\n# request:\r\n" << parser.method() << " " << parser.fullUrl() << "\r\n"; + log << "# header:\r\n"; + + for (auto &pr : parser.getHeader()) { + log << pr.first << " : " << pr.second << "\r\n"; + } + + auto &content = parser.content(); + log << "# content:\r\n" << (content.size() > 4 * 1024 ? content.substr(0, 4 * 1024) : content) << "\r\n"; + + if (size > 0 && size < 4 * 1024) { + log << "# response:\r\n" << body << "\r\n"; + } else { + log << "# response size:" << size << "\r\n"; + } +} + +void handle_http_request(const py::object &check_route, const py::object &submit_coro, const Parser &parser, const HttpSession::HttpResponseInvoker &invoker, bool &consumed, toolkit::SockInfo &sender) { + py::gil_scoped_acquire guard; + + py::dict scope; + scope["type"] = "http"; + scope["http_version"] = "1.1"; + scope["method"] = parser.method(); + scope["path"] = parser.url(); + scope["query_string"] = parser.params(); + py::list hdrs; + for (auto &kv : parser.getHeader()) { + // Starlette/ASGI 规范要求 headers 的 key 必须全小写字节串 + hdrs.append(py::make_tuple(py::bytes(toolkit::strToLower(kv.first.data())), py::bytes(kv.second))); + } + scope["headers"] = hdrs; + + bool ok = check_route(scope).cast(); + if (!ok) { + return; + } + consumed = true; + + Json::Value val; + HttpSession::KeyValue headerOut; + // http api被python拦截了,再api统一鉴权 + try { + auto args = getAllArgs(parser); + auto allArgs = ArgsMap(parser, args); + // Python接口要求登录鉴权 + CHECK_SECRET(); + } catch (std::exception &ex) { + auto ex1 = dynamic_cast(&ex); + if (ex1) { + val["code"] = ex1->code(); + } else { + val["code"] = API::Exception; + } + val["msg"] = ex.what(); + headerOut["Content-Type"] = "application/json"; + invoker(200, headerOut, val.toStyledString()); + return; + } + + StrCaseMap resp_headers; + std::string resp_body; + int status = 500; + auto send = py::cpp_function([parser, invoker, status, resp_body, resp_headers](const py::dict &msg) mutable { + auto type = msg["type"].cast(); + if (type == "http.response.start") { + status = msg["status"].cast(); + for (auto tup : msg["headers"].cast()) { + auto t = tup.cast(); + resp_headers[t[0].cast()] = t[1].cast(); + } + return; + } + + if (type == "http.response.body") { + resp_body += msg["body"].cast(); + // 💥 只在 more_body=False 时回调 + bool more = msg.contains("more_body") && msg["more_body"].cast(); + if (!more) { + python_api_debug(parser, resp_body); + invoker(status, resp_headers, resp_body); + } + } + }); + + submit_coro(scope, py::bytes(parser.content()), send); +} + +class MuxerDelegatePython : public MediaSinkInterface { +public: + MuxerDelegatePython(py::object object) { + _py_muxer = std::move(object); + _input_frame = _py_muxer.attr("inputFrame"); + _add_track = _py_muxer.attr("addTrack"); + _add_track_completed = _py_muxer.attr("addTrackCompleted"); + } + + ~MuxerDelegatePython() override { + py::gil_scoped_acquire guard; + try { + auto destroy = _py_muxer.attr("destroy"); + destroy(); + destroy = py::function(); + } catch (std::exception &ex) { + ErrorL << "destroy python muxer failed: " << ex.what(); + } + _input_frame = py::function(); + _add_track = py::function(); + _add_track_completed = py::function(); + _py_muxer = py::function(); + } + + bool addTrack(const Track::Ptr &track) override { + py::gil_scoped_acquire guard; + return _add_track ? _add_track(track).cast() : false; + } + + void addTrackCompleted() override { + py::gil_scoped_acquire guard; + if (_add_track_completed) { + _add_track_completed(); + } + } + + bool inputFrame(const Frame::Ptr &frame) override { + py::gil_scoped_acquire guard; + return _input_frame ? _input_frame(frame).cast() : false; + } + +private: + py::object _py_muxer; + py::function _input_frame; + py::function _add_track; + py::function _add_track_completed; +}; + +PYBIND11_EMBEDDED_MODULE(mk_loader, m) { + m.def("log", [](int lev, const char *file, int line, const char *func, const char *content) { + py::gil_scoped_release release; + LoggerWrapper::printLog(::toolkit::getLogger(), lev, file, func, line, content); + }); + + m.def("get_config", [](const std::string &key) -> std::string { + py::gil_scoped_release release; + const auto it = mINI::Instance().find(key); + if (it != mINI::Instance().end()) { + return it->second; + } + return ""; + }); + + m.def("get_full_path", [](const std::string &path, const std::string ¤t_path) -> std::string { + py::gil_scoped_release release; + switch (path.front()) { + case '/': + case '\\': return path; + default: return File::absolutePath(path, current_path); + } + }, py::arg("path"), py::arg("current_path") = ""); + + m.def("set_config", [](const std::string &key, const std::string &value) -> bool { + py::gil_scoped_release release; + mINI::Instance()[key]= value; + return true; + }); + + m.def("update_config", []() { + NOTICE_EMIT(BroadcastReloadConfigArgs, Broadcast::kBroadcastReloadConfig); + mINI::Instance().dumpFile(g_ini_file); + return true; + }); + + m.def("publish_auth_invoker_do", [](const py::capsule &cap, const std::string &err, const py::dict &opt) { + ProtocolOption option; + option.load(to_native(opt)); + // 执行c++代码时释放gil锁 + py::gil_scoped_release release; + auto &invoker = to_native(cap); + invoker(err, option); + }); + + m.def("play_auth_invoker_do", [](const py::capsule &cap, const std::string &err) { + // 执行c++代码时释放gil锁 + py::gil_scoped_release release; + auto &invoker = to_native(cap); + invoker(err); + }); + + m.def("rtsp_get_realm_invoker_do", [](const py::capsule &cap, const std::string &realm) { + // 执行c++代码时释放gil锁 + py::gil_scoped_release release; + auto &invoker = to_native(cap); + invoker(realm); + }); + + m.def("rtsp_auth_invoker_do", [](const py::capsule &cap, bool encrypted, const std::string &pwd_or_md5) { + // 执行c++代码时释放gil锁 + py::gil_scoped_release release; + auto &invoker = to_native(cap); + invoker(encrypted, pwd_or_md5); + }); + + m.def("close_player_invoker_do", [](const py::capsule &cap) { + // 执行c++代码时释放gil锁 + py::gil_scoped_release release; + auto &invoker = to_native>(cap); + invoker(); + }); + + m.def("http_access_invoker_do", [](const py::capsule &cap, const std::string &errMsg,const std::string &accessPath, int cookieLifeSecond) { + // 执行c++代码时释放gil锁 + py::gil_scoped_release release; + auto &invoker = to_native(cap); + invoker(errMsg, accessPath, cookieLifeSecond); + }); + + // add_stream_proxy(vhost, app, stream, url, cb, retry_count=-1, force=False, + // rtp_type=0, timeout_sec=0, opt={}) + // opt 字典可包含 ProtocolOption 的所有字段,以及其他透传给 Player 的 key-value 参数 + m.def("add_stream_proxy", + [](const std::string &vhost, const std::string &app, const std::string &stream, + const std::string &url, const py::object &cb, + int retry_count, bool force, float timeout_sec, + const py::dict &opt) { + mINI args = to_native(opt); + ProtocolOption option; + option.load(args); + MediaTuple tuple { vhost.empty() ? DEFAULT_VHOST : vhost, app, stream, "" }; + + // 用 shared_ptr 包裹 py::object,使其析构(dec_ref)可在受控环境下执行。 + // 必须在 GIL 持有期间创建该 shared_ptr(此处仍在 GIL 内)。 + // 自定义 deleter 保证即使在非 Python 线程析构时也会先获取 GIL。 + auto cb_ptr = std::shared_ptr( + new py::object(cb), + [](py::object *p) { + // dec_ref / 析构 py::object 需要 GIL + py::gil_scoped_acquire guard; + delete p; + } + ); + + py::gil_scoped_release release; + EventPollerPool::Instance().getPoller(false)->async([=]() mutable { + addStreamProxy(tuple, url, retry_count, force, option, timeout_sec, args, + [cb_ptr](const SockException &ex, const std::string &key) { + // cb_ptr 按值捕获(shared_ptr 的拷贝,纯 C++ 操作,无需 GIL) + // inc_ref/dec_ref/调用 Python 对象均在 gil_scoped_acquire 保护下进行 + py::gil_scoped_acquire guard; + try { + (*cb_ptr)(ex ? ex.what() : "", key); + } catch (py::error_already_set &e) { + WarnL << "Python exception in add_stream_proxy callback: " << e.what(); + } + // cb_ptr 在此析构(局部副本),dec_ref 由自定义 deleter 在 GIL 下执行 + }); + }); + }, + py::arg("vhost"), py::arg("app"), py::arg("stream"), py::arg("url"), py::arg("cb"), + py::arg("retry_count") = -1, py::arg("force") = false, + py::arg("timeout_sec") = 0.0f, py::arg("opt") = py::dict() + ); + + // update_stream_proxy(vhost, app, stream, url, opt={}) + // 更新已有拉流代理的 url 和参数,流不存在时抛出异常 + m.def("update_stream_proxy", + [](const std::string &vhost, const std::string &app, const std::string &stream, + const std::string &url, const py::dict &opt) { + mINI args = to_native(opt); + MediaTuple tuple { vhost.empty() ? DEFAULT_VHOST : vhost, app, stream, "" }; + py::gil_scoped_release release; + updateStreamProxy(tuple, url, args); + }, + py::arg("vhost"), py::arg("app"), py::arg("stream"), py::arg("url"), + py::arg("opt") = py::dict() + ); + + m.def("set_fastapi", [](const py::object &check_route, const py::object &submit_coro) { + static void *fastapi_tag = nullptr; + NoticeCenter::Instance().delListener(&fastapi_tag, Broadcast::kBroadcastHttpRequest); + NoticeCenter::Instance().addListener(&fastapi_tag, Broadcast::kBroadcastHttpRequest, [check_route, submit_coro](BroadcastHttpRequestArgs) { + handle_http_request(check_route, submit_coro, parser, invoker, consumed, sender); + }); + }); + + py::enum_(m, "TrackType") + .value("Invalid", TrackInvalid) + .value("Video", TrackVideo) + .value("Audio", TrackAudio) + .value("Title", TrackTitle) + .value("Application", TrackApplication) + .export_values(); + + py::class_(m, "MediaSource") + .def("getSchema", &MediaSource::getSchema) + .def("getUrl", &MediaSource::getUrl) + .def("getMediaTuple", &MediaSource::getMediaTuple) + .def("getTimeStamp", &MediaSource::getTimeStamp) + .def("setTimeStamp", &MediaSource::setTimeStamp) + .def("getBytesSpeed", &MediaSource::getBytesSpeed) + .def("getTotalBytes", &MediaSource::getTotalBytes) + .def("getCreateStamp", &MediaSource::getCreateStamp) + .def("getAliveSecond", &MediaSource::getAliveSecond) + .def("readerCount", &MediaSource::readerCount) + .def("totalReaderCount", &MediaSource::totalReaderCount) + .def("getOriginType", &MediaSource::getOriginType) + .def("getOriginUrl", &MediaSource::getOriginUrl) + .def("getOriginSock", &MediaSource::getOriginSock) + .def("seekTo", &MediaSource::seekTo) + .def("pause", &MediaSource::pause) + .def("speed", &MediaSource::speed) + .def("close", &MediaSource::close) + .def("setupRecord", &MediaSource::setupRecord) + .def("isRecording", &MediaSource::isRecording) + .def("stopSendRtp", &MediaSource::stopSendRtp) + .def("getLossRate", &MediaSource::getLossRate) + .def("getMuxer", &MediaSource::getMuxer); + + py::class_>(m, "MediaTuple") + .def_readwrite("vhost", &MediaTuple::vhost) + .def_readwrite("app", &MediaTuple::app) + .def_readwrite("stream", &MediaTuple::stream) + .def_readwrite("params", &MediaTuple::params) + .def("shortUrl", &MediaTuple::shortUrl); + + py::class_>(m, "SockException").def("what", &SockException::what).def("code", &SockException::getErrCode); + + py::class_>(m, "Parser") + .def("method", &Parser::method) + .def("url", &Parser::url) + .def("status", &Parser::status) + .def("fullUrl", &Parser::fullUrl) + .def("protocol", &Parser::protocol) + .def("statusStr", &Parser::statusStr) + .def("content", &Parser::content) + .def("params", &Parser::params) + .def("getHeader", [](Parser *thiz) { + py::dict ret; + for (auto &pr : thiz->getHeader()) { + ret[pr.first.data()] = pr.second; + } + return ret; + }); + + py::enum_(m, "RecordType") + .value("hls", Recorder::type_hls) + .value("mp4", Recorder::type_mp4) + .value("hls_fmp4", Recorder::type_hls_fmp4) + .value("fmp4", Recorder::type_fmp4) + .value("ts", Recorder::type_ts) + .export_values(); + +#define OPT(key) .def_readwrite(#key, &ProtocolOption::key) + py::class_>(m, "ProtocolOption") OPT_VALUE(OPT); +#undef OPT + + py::class_>(m, "MultiMediaSourceMuxer") + .def("totalReaderCount", static_cast(&MultiMediaSourceMuxer::totalReaderCount)) + .def("isEnabled", &MultiMediaSourceMuxer::isEnabled) + .def("setupRecord", &MultiMediaSourceMuxer::setupRecord) + .def("startRecord", &MultiMediaSourceMuxer::startRecord) + .def("isRecording", &MultiMediaSourceMuxer::isRecording) + .def("startSendRtp", &MultiMediaSourceMuxer::startSendRtp) + .def("stopSendRtp", &MultiMediaSourceMuxer::stopSendRtp) + .def("getOption", &MultiMediaSourceMuxer::getOption) + .def("getMediaTuple", &MultiMediaSourceMuxer::getMediaTuple); + + py::class_(m, "Track") + .def("getCodecId", &Track::getCodecId) + .def("getCodecName", &Track::getCodecName) + .def("getTrackType", &Track::getTrackType) + .def("getTrackTypeStr", &Track::getTrackTypeStr) + .def("setIndex", &Track::setIndex) + .def("getIndex", &Track::getIndex) + .def("getVideoKeyFrames", &Track::getVideoKeyFrames) + .def("getFrames", &Track::getFrames) + .def("getVideoGopSize", &Track::getVideoGopSize) + .def("getVideoGopInterval", &Track::getVideoGopInterval) + .def("getDuration", &Track::getDuration) + .def("ready", &Track::ready) + .def("update", &Track::update) + .def("getSdp", &Track::getSdp) + .def("getExtraData", &Track::getExtraData) + .def("setExtraData", &Track::setExtraData) + .def("getBitRate", &Track::getBitRate) + .def("setBitRate", &Track::setBitRate) + .def("getVideoHeight",[](Track *thiz) { + auto ptr = dynamic_cast(thiz); + return ptr ? ptr->getVideoHeight() : 0; + }) + .def("getVideoWidth", [](Track *thiz) { + auto ptr = dynamic_cast(thiz); + return ptr ? ptr->getVideoWidth() : 0; + }) + .def("getVideoFps", [](Track *thiz) { + auto ptr = dynamic_cast(thiz); + return ptr ? ptr->getVideoFps() : 0; + }) + .def("getAudioSampleRate",[](Track *thiz) { + auto ptr = dynamic_cast(thiz); + return ptr ? ptr->getAudioSampleRate() : 0; + }) + .def("getAudioSampleBit", [](Track *thiz) { + auto ptr = dynamic_cast(thiz); + return ptr ? ptr->getAudioSampleBit() : 0; + }) + .def("getAudioChannel", [](Track *thiz) { + auto ptr = dynamic_cast(thiz); + return ptr ? ptr->getAudioChannel() : 0; + }); + + py::class_(m, "Frame") + .def("data", &Frame::data) + .def("size", &Frame::size) + .def("toString", &Frame::toString) + .def("getCapacity", &Frame::getCapacity) + .def("getCodecId", &Frame::getCodecId) + .def("getCodecName", &Frame::getCodecName) + .def("getTrackType", &Frame::getTrackType) + .def("getTrackTypeStr", &Frame::getTrackTypeStr) + .def("setIndex", &Frame::setIndex) + .def("getIndex", &Frame::getIndex) + .def("dts", &Frame::dts) + .def("pts", &Frame::pts) + .def("prefixSize", &Frame::prefixSize) + .def("keyFrame", &Frame::keyFrame) + .def("configFrame", &Frame::configFrame) + .def("cacheAble", &Frame::cacheAble) + .def("dropAble", &Frame::dropAble) + .def("decodeAble", &Frame::decodeAble); +} + +namespace mediakit { + +inline bool set_env(const char *name, const char *value) { +#if defined(_WIN32) + std::string env_str = std::string(name) + "=" + value; + return _putenv(env_str.c_str()) == 0; +#else + return setenv(name, value, 1) == 0; // overwrite = 1 +#endif +} + +bool set_python_path() { + const char *env_var = std::getenv("PYTHONPATH"); + if (env_var && *env_var) { + PrintI("PYTHONPATH is already set to: %s", env_var); + return false; + } + auto default_path = exeDir() + "/python:" + exeDir() + "/pymkui/backend"; + // 1 表示覆盖已存在的值 + if (!set_env("PYTHONPATH", default_path.data())) { + PrintW("Failed to set PYTHONPATH"); + return false; + } + PrintI("PYTHONPATH was not set. Set to default: %s", default_path.data()); + return true; +} + +static std::shared_ptr g_instance; + +PythonInvoker &PythonInvoker::Instance() { + static toolkit::onceToken s_token([]() { + g_instance.reset(new PythonInvoker); + }); + + return *g_instance; +} + +void PythonInvoker::release() { + g_instance = nullptr; +} + +PythonInvoker::PythonInvoker() { + // 确保日志一直可用 + _logger = Logger::Instance().shared_from_this(); + set_python_path(); // 确保 PYTHONPATH 在第一次调用时设置 + _interpreter = new py::scoped_interpreter; + _rel = new py::gil_scoped_release; + + NoticeCenter::Instance().addListener(this, Broadcast::kBroadcastReloadConfig, [this] (BroadcastReloadConfigArgs) { + py::gil_scoped_acquire guard; + if (_on_reload_config) { + _on_reload_config(); + } + }); + + NoticeCenter::Instance().addListener(this, Broadcast::kBroadcastCreateMuxer, [this](BroadcastCreateMuxerArgs) { + py::gil_scoped_acquire guard; + if (_on_create_muxer) { + auto py_muxer = _on_create_muxer(to_python_ref(sender)); + if (py_muxer && !py_muxer.is_none()) { + delegate = std::make_shared(std::move(py_muxer)); + } + } + }); +} + +PythonInvoker::~PythonInvoker() { + NoticeCenter::Instance().delListener(this, Broadcast::kBroadcastReloadConfig); + { + py::gil_scoped_acquire gil; // 加锁 + if (_on_exit) { + _on_exit(); + } + _on_exit = py::function(); + _on_publish = py::function(); + _on_play = py::function(); + _on_flow_report = py::function(); + _on_reload_config = py::function(); + _on_media_changed = py::function(); + _on_player_proxy_failed = py::function(); + _on_get_rtsp_realm = py::function(); + _on_rtsp_auth = py::function(); + _on_stream_not_found = py::function(); + _on_record_mp4 = py::function(); + _on_record_ts = py::function(); + _on_stream_none_reader = py::function(); + _on_send_rtp_stopped = py::function(); + _on_http_access = py::function(); + _on_rtp_server_timeout = py::function(); + _on_create_muxer = py::function(); + _module = py::module(); + } + delete _rel; + delete _interpreter; +} + +#define GET_FUNC(instance, name) \ + if (hasattr(instance, #name)) { \ + _##name = instance.attr(#name); \ + } + +void PythonInvoker::load(const std::string &module_name) { + try { + py::gil_scoped_acquire gil; // 加锁 + _module = py::module::import(module_name.c_str()); + GET_FUNC(_module, on_exit); + GET_FUNC(_module, on_publish); + GET_FUNC(_module, on_play); + GET_FUNC(_module, on_flow_report); + GET_FUNC(_module, on_reload_config); + GET_FUNC(_module, on_media_changed); + GET_FUNC(_module, on_player_proxy_failed); + GET_FUNC(_module, on_get_rtsp_realm); + GET_FUNC(_module, on_rtsp_auth); + GET_FUNC(_module, on_stream_not_found); + GET_FUNC(_module, on_record_mp4); + GET_FUNC(_module, on_record_ts); + GET_FUNC(_module, on_stream_none_reader); + GET_FUNC(_module, on_send_rtp_stopped); + GET_FUNC(_module, on_http_access); + GET_FUNC(_module, on_rtp_server_timeout); + GET_FUNC(_module, on_create_muxer); + + if (hasattr(_module, "on_start")) { + py::object on_start = _module.attr("on_start"); + if (on_start) { + on_start(); + } + } + } catch (py::error_already_set &e) { + PrintE("Python exception:%s", e.what()); + } +} + +bool PythonInvoker::on_publish(BroadcastMediaPublishArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_publish) { + return false; + } + return _on_publish(getOriginTypeString(type), to_python(args), to_python(invoker), to_python(sender)).cast(); +} + +bool PythonInvoker::on_play(BroadcastMediaPlayedArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_play) { + return false; + } + return _on_play(to_python(args), to_python(invoker), to_python(sender)).cast(); +} + +bool PythonInvoker::on_flow_report(BroadcastFlowReportArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_flow_report) { + return false; + } + return _on_flow_report(to_python(args), totalBytes, totalDuration, isPlayer, to_python(sender)).cast(); +} + +bool PythonInvoker::on_media_changed(BroadcastMediaChangedArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_media_changed) { + return false; + } + return _on_media_changed(bRegist, to_python_ref(sender)).cast(); +} + +bool PythonInvoker::on_player_proxy_failed(BroadcastPlayerProxyFailedArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_player_proxy_failed) { + return false; + } + return _on_player_proxy_failed(sender.getUrl(), to_python_ref(sender.getMediaTuple()), to_python_ref(ex)).cast(); +} + +bool PythonInvoker::on_get_rtsp_realm(BroadcastOnGetRtspRealmArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_get_rtsp_realm) { + return false; + } + return _on_get_rtsp_realm(to_python(args), to_python(invoker), to_python(sender)).cast(); +} + +bool PythonInvoker::on_rtsp_auth(BroadcastOnRtspAuthArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_rtsp_auth) { + return false; + } + return _on_rtsp_auth(to_python(args), realm, user_name, must_no_encrypt, to_python(invoker), to_python(sender)).cast(); +} + +bool PythonInvoker::on_stream_not_found(BroadcastNotFoundStreamArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_stream_not_found) { + return false; + } + return _on_stream_not_found(to_python(args), to_python(sender), to_python(closePlayer)).cast(); +} + +bool PythonInvoker::on_record_mp4(BroadcastRecordMP4Args) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_record_mp4) { + return false; + } + return _on_record_mp4(to_python(info)).cast(); +} + +bool PythonInvoker::on_record_ts(BroadcastRecordTsArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_record_ts) { + return false; + } + return _on_record_ts(to_python(info)).cast(); +} + +bool PythonInvoker::on_stream_none_reader(BroadcastStreamNoneReaderArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_stream_none_reader) { + return false; + } + return _on_stream_none_reader(to_python_ref(sender)).cast(); +} + +bool PythonInvoker::on_send_rtp_stopped(BroadcastSendRtpStoppedArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_send_rtp_stopped) { + return false; + } + return _on_send_rtp_stopped(to_python_ref(sender), ssrc, to_python_ref(ex)).cast(); +} + +bool PythonInvoker::on_http_access(BroadcastHttpAccessArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_http_access) { + return false; + } + return _on_http_access(to_python_ref(parser), path, file_path, is_dir, to_python(invoker), to_python(sender)).cast(); +} + +bool PythonInvoker::on_rtp_server_timeout(BroadcastRtpServerTimeoutArgs) const { + py::gil_scoped_acquire gil; // 确保在 Python 调用期间持有 GIL + if (!_on_rtp_server_timeout) { + return false; + } + return _on_rtp_server_timeout(local_port, to_python_ref(tuple), tcp_mode, re_use_port, ssrc).cast(); +} + +} // namespace mediakit + +#endif diff --git a/server/pyinvoker.h b/server/pyinvoker.h new file mode 100644 index 00000000..3d9a55e6 --- /dev/null +++ b/server/pyinvoker.h @@ -0,0 +1,95 @@ + +#ifndef PYINVOKER_H +#define PYINVOKER_H + +#if defined(ENABLE_PYTHON) + +#include +#include +#include +#include +#include "Util/logger.h" +#include "Common/config.h" +#include "Common/MediaSource.h" +#include "Player/PlayerProxy.h" +#include "Rtsp/RtspSession.h" +#include "Http/HttpSession.h" + +namespace py = pybind11; + +namespace mediakit { + +class PythonInvoker : public std::enable_shared_from_this{ +public: + ~PythonInvoker(); + + static PythonInvoker& Instance(); + static void release(); + + void load(const std::string &module_name); + bool on_publish(BroadcastMediaPublishArgs) const; + bool on_play(BroadcastMediaPlayedArgs) const; + bool on_flow_report(BroadcastFlowReportArgs) const; + bool on_media_changed(BroadcastMediaChangedArgs) const; + bool on_player_proxy_failed(BroadcastPlayerProxyFailedArgs) const; + bool on_get_rtsp_realm(BroadcastOnGetRtspRealmArgs) const; + bool on_rtsp_auth(BroadcastOnRtspAuthArgs) const; + bool on_stream_not_found(BroadcastNotFoundStreamArgs) const; + bool on_record_mp4(BroadcastRecordMP4Args) const; + bool on_record_ts(BroadcastRecordTsArgs) const; + bool on_stream_none_reader(BroadcastStreamNoneReaderArgs) const; + bool on_send_rtp_stopped(BroadcastSendRtpStoppedArgs) const; + bool on_http_access(BroadcastHttpAccessArgs) const; + bool on_rtp_server_timeout(BroadcastRtpServerTimeoutArgs) const; + +private: + PythonInvoker(); + +private: + py::gil_scoped_release *_rel; + py::scoped_interpreter *_interpreter; + std::shared_ptr _logger; + py::module _module; + + // 程序退出 + py::function _on_exit; + // 推流鉴权 + py::function _on_publish; + // 播放鉴权 + py::function _on_play; + // 流量汇报接口 + py::function _on_flow_report; + // 配置文件热更新回调 + py::function _on_reload_config; + // 媒体注册注销 + py::function _on_media_changed; + // 拉流代理失败 + py::function _on_player_proxy_failed; + // rtsp播放是否开启专属鉴权 + py::function _on_get_rtsp_realm; + // rtsp播放或推流鉴权回调 + py::function _on_rtsp_auth; + // 播放一个不存在的流时触发 + py::function _on_stream_not_found; + // 生成mp4录制文件回调 + py::function _on_record_mp4; + // 生成hls ts/fmp4切片文件回调 + py::function _on_record_ts; + // 流无人观看事件 + py::function _on_stream_none_reader; + // rtp转发失败事件 + py::function _on_send_rtp_stopped; + // http访问鉴权事件 + py::function _on_http_access; + // rtp服务收流超时事件 + py::function _on_rtp_server_timeout; + // 创建Python muxer对象 + py::function _on_create_muxer; + + +}; + +} // namespace mediakit + +#endif +#endif // PYINVOKER_H \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7ce2eb88..2feae7e5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,14 +26,6 @@ file(GLOB MediaKit_SRC_LIST ${CMAKE_CURRENT_SOURCE_DIR}/*/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/*/*.h) -if(NOT ENABLE_SRT) - file(GLOB SRT_SRC_LIST - ${CMAKE_CURRENT_SOURCE_DIR}/Srt/*.c - ${CMAKE_CURRENT_SOURCE_DIR}/Srt/*.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/Srt/*.h) - list(REMOVE_ITEM MediaKit_SRC_LIST ${SRT_SRC_LIST}) -endif() - if(USE_SOLUTION_FOLDERS AND (NOT GROUP_BY_EXPLORER)) # 在 IDE 中对文件进行分组, 源文件和头文件分开 set_file_group("${CMAKE_CURRENT_SOURCE_DIR}" ${MediaKit_SRC_LIST}) @@ -67,8 +59,9 @@ update_cached_list(MK_LINK_LIBRARIES ZLMediaKit::MediaKit) if(ENABLE_CXX_API) # 保留目录结构 install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ - DESTINATION ${INSTALL_PATH_INCLUDE}/ZLMediaKit - REGEX ".*[.](md|cpp)$" EXCLUDE) + DESTINATION ${INSTALL_PATH_INCLUDE}/ZLMediaKit + FILES_MATCHING + PATTERN "*.h") install(TARGETS zlmediakit DESTINATION ${INSTALL_PATH_LIB}) endif () diff --git a/src/Codec/Transcode.cpp b/src/Codec/Transcode.cpp index ebe8770a..552f37d4 100644 --- a/src/Codec/Transcode.cpp +++ b/src/Codec/Transcode.cpp @@ -35,13 +35,8 @@ static string ffmpeg_err(int errnum) { return errbuf; } -std::shared_ptr alloc_av_packet() { - auto pkt = std::shared_ptr(av_packet_alloc(), [](AVPacket *pkt) { - av_packet_free(&pkt); - }); - pkt->data = NULL; // packet data will be allocated by the encoder - pkt->size = 0; - return pkt; +std::unique_ptr alloc_av_packet() { + return std::unique_ptr(av_packet_alloc(), [](AVPacket *pkt) { av_packet_free(&pkt); }); } ////////////////////////////////////////////////////////////////////////////////////////// @@ -165,7 +160,9 @@ void TaskManager::startThread(const string &name) { _thread.reset(new thread([this, name]() { onThreadRun(name); }), [](thread *ptr) { - ptr->join(); + if (ptr->joinable()) { + ptr->join(); + } delete ptr; }); } @@ -242,10 +239,6 @@ FFmpegFrame::FFmpegFrame(std::shared_ptr frame) { } FFmpegFrame::~FFmpegFrame() { - if (_data) { - delete[] _data; - _data = nullptr; - } } AVFrame *FFmpegFrame::get() const { @@ -253,9 +246,26 @@ AVFrame *FFmpegFrame::get() const { } void FFmpegFrame::fillPicture(AVPixelFormat target_format, int target_width, int target_height) { - assert(_data == nullptr); - _data = new char[av_image_get_buffer_size(target_format, target_width, target_height, 32)]; - av_image_fill_arrays(_frame->data, _frame->linesize, (uint8_t *) _data, target_format, target_width, target_height, 32); + auto buffer_size = av_image_get_buffer_size(target_format, target_width, target_height, 32); + _data = std::unique_ptr(new char[buffer_size]); + av_image_fill_arrays(_frame->data, _frame->linesize, (uint8_t *)_data.get(), target_format, target_width, target_height, 32); +} + +int FFmpegFrame::getChannels() const { + if (!_frame) return 0; +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + return _frame->ch_layout.nb_channels; +#else + return _frame->channels; +#endif +} + +// 资源池复用前调用 +void FFmpegFrame::reset() { + _data.reset(); + if (_frame) { + av_frame_unref(_frame.get()); // 清理AVFrame数据引用 + } } /////////////////////////////////////////////////////////////////////////// @@ -326,6 +336,7 @@ static inline const AVCodec *getCodecByName(const std::vector &code FFmpegDecoder::FFmpegDecoder(const Track::Ptr &track, int thread_num, const std::vector &codec_name) { setupFFmpeg(); + _frame_pool.setSize(AV_NUM_DATA_POINTERS); const AVCodec *codec = nullptr; const AVCodec *codec_default = nullptr; if (!codec_name.empty()) { @@ -421,22 +432,25 @@ FFmpegDecoder::FFmpegDecoder(const Track::Ptr &track, int thread_num, const std: _context->flags |= AV_CODEC_FLAG_LOW_DELAY; _context->flags2 |= AV_CODEC_FLAG2_FAST; if (track->getTrackType() == TrackVideo) { - auto video = static_pointer_cast(track); - _context->width = video->getVideoWidth(); - _context->height = video->getVideoHeight(); - InfoL << "decode video " << video->getCodecName() << " " << _context->width << "x" << _context->height; - } else { - auto audio = static_pointer_cast(track); - InfoL << "decode audio " << audio->getCodecName() << " " << audio->getAudioSampleRate() << "x" << audio->getAudioChannel(); - switch (track->getCodecId()) { - case CodecG711A: - case CodecG711U: { - _context->channels = audio->getAudioChannel(); - _context->sample_rate = audio->getAudioSampleRate(); - _context->channel_layout = av_get_default_channel_layout(_context->channels); - break; - } - default: break; + _context->width = static_pointer_cast(track)->getVideoWidth(); + _context->height = static_pointer_cast(track)->getVideoHeight(); + InfoL << "media source :" << _context->width << " X " << _context->height; + } + + switch (track->getCodecId()) { + case CodecG711A: + case CodecG711U: { + AudioTrack::Ptr audio = static_pointer_cast(track); + +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + av_channel_layout_default(&_context->ch_layout, audio->getAudioChannel()); +#else + _context->channels = audio->getAudioChannel(); + _context->channel_layout = av_get_default_channel_layout(_context->channels); +#endif + + _context->sample_rate = audio->getAudioSampleRate(); + break; } } AVDictionary *dict = nullptr; @@ -490,7 +504,7 @@ FFmpegDecoder::~FFmpegDecoder() { void FFmpegDecoder::flush() { while (true) { - auto out_frame = std::make_shared(); + auto out_frame = _frame_pool.obtain2(); auto ret = avcodec_receive_frame(_context.get(), out_frame->get()); if (ret == AVERROR(EAGAIN)) { avcodec_send_packet(_context.get(), nullptr); @@ -533,7 +547,7 @@ bool FFmpegDecoder::inputFrame(const Frame::Ptr &frame, bool live, bool async, b inputFrame_l(frame_cache, live, enable_merge); // 此处模拟解码太慢导致的主动丢帧 [AUTO-TRANSLATED:fc8bea8a] // Here simulates decoding too slow, resulting in active frame dropping - //usleep(100 * 1000); + // usleep(100 * 1000); }); } @@ -541,7 +555,7 @@ bool FFmpegDecoder::decodeFrame(const char *data, size_t size, uint64_t dts, uin TimeTicker2(30, TraceL); auto pkt = alloc_av_packet(); - pkt->data = (uint8_t *) data; + pkt->data = (uint8_t *)data; pkt->size = size; pkt->dts = dts; pkt->pts = pts; @@ -557,8 +571,8 @@ bool FFmpegDecoder::decodeFrame(const char *data, size_t size, uint64_t dts, uin return false; } - while (true) { - auto out_frame = std::make_shared(); + for (;;) { + auto out_frame = _frame_pool.obtain2(); ret = avcodec_receive_frame(_context.get(), out_frame->get()); if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { break; @@ -589,7 +603,6 @@ void FFmpegDecoder::onDecode(const FFmpegFrame::Ptr &frame) { } //////////////////////////////////////////////////////////////////////////////////////////////////////////// - FFmpegAudioFifo::~FFmpegAudioFifo() { if (_fifo) { av_audio_fifo_free(_fifo); @@ -603,15 +616,20 @@ int FFmpegAudioFifo::size() const { bool FFmpegAudioFifo::Write(const AVFrame *frame) { _format = (AVSampleFormat)frame->format; +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + int channels = frame->ch_layout.nb_channels; +#else + int channels = frame->channels; +#endif if (!_fifo) { - _fifo = av_audio_fifo_alloc(_format, frame->channels, frame->nb_samples); + _fifo = av_audio_fifo_alloc(_format, channels, frame->nb_samples); if (!_fifo) { - WarnL << "av_audio_fifo_alloc " << frame->channels << "x" << frame->nb_samples << "error"; + WarnL << "av_audio_fifo_alloc " << channels << "x" << frame->nb_samples << "error"; return false; } } - _channels = frame->channels; + _channels = channels; if (_samplerate != frame->sample_rate) { _samplerate = frame->sample_rate; // 假定传入frame的时间戳是以ms为单位的 @@ -643,7 +661,12 @@ bool FFmpegAudioFifo::Read(AVFrame *frame, int sample_size) { av_samples_get_buffer_size(frame->linesize, _channels, sample_size, _format, 0); frame->nb_samples = sample_size; frame->format = _format; +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + av_channel_layout_default(&frame->ch_layout, _channels); +#else frame->channel_layout = av_get_default_channel_layout(_channels); + frame->channels = _channels; +#endif frame->sample_rate = _samplerate; if (fabs(_tsp) > DBL_EPSILON) { frame->pts = _tsp; @@ -665,41 +688,75 @@ bool FFmpegAudioFifo::Read(AVFrame *frame, int sample_size) { } //////////////////////////////////////////////////////////////////////////////////////////////////////////// - +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 +FFmpegSwr::FFmpegSwr(AVSampleFormat output, AVChannelLayout *ch_layout, int samplerate) { + _target_format = output; + av_channel_layout_copy(&_target_ch_layout, ch_layout); + _target_samplerate = samplerate; +} +#else FFmpegSwr::FFmpegSwr(AVSampleFormat output, int channel, int channel_layout, int samplerate) { _target_format = output; _target_channels = channel; _target_channel_layout = channel_layout; _target_samplerate = samplerate; + + _swr_frame_pool.setSize(AV_NUM_DATA_POINTERS); } +#endif FFmpegSwr::~FFmpegSwr() { if (_ctx) { swr_free(&_ctx); } +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + av_channel_layout_uninit(&_target_ch_layout); +#endif } FFmpegFrame::Ptr FFmpegSwr::inputFrame(const FFmpegFrame::Ptr &frame) { if (frame->get()->format == _target_format && - frame->get()->channels == _target_channels && - frame->get()->channel_layout == (uint64_t)_target_channel_layout && + +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + !av_channel_layout_compare(&(frame->get()->ch_layout), &_target_ch_layout) && +#else + frame->get()->channels == _target_channels && frame->get()->channel_layout == (uint64_t)_target_channel_layout && +#endif + frame->get()->sample_rate == _target_samplerate) { // 不转格式 [AUTO-TRANSLATED:31dc6ae1] // Do not convert format return frame; } if (!_ctx) { + +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + _ctx = swr_alloc(); + swr_alloc_set_opts2(&_ctx, + &_target_ch_layout, _target_format, _target_samplerate, + &frame->get()->ch_layout, (AVSampleFormat)frame->get()->format, frame->get()->sample_rate, + 0, nullptr); +#else _ctx = swr_alloc_set_opts(nullptr, _target_channel_layout, _target_format, _target_samplerate, frame->get()->channel_layout, (AVSampleFormat) frame->get()->format, frame->get()->sample_rate, 0, nullptr); +#endif + InfoL << "swr_alloc_set_opts:" << av_get_sample_fmt_name((enum AVSampleFormat) frame->get()->format) << " -> " << av_get_sample_fmt_name(_target_format); } if (_ctx) { - auto out = std::make_shared(); + auto out = _swr_frame_pool.obtain2(); out->get()->format = _target_format; + +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + out->get()->ch_layout = _target_ch_layout; + av_channel_layout_copy(&(out->get()->ch_layout), &_target_ch_layout); +#else out->get()->channel_layout = _target_channel_layout; out->get()->channels = _target_channels; +#endif + out->get()->sample_rate = _target_samplerate; out->get()->pkt_dts = frame->get()->pkt_dts; out->get()->pts = frame->get()->pts; @@ -721,6 +778,8 @@ FFmpegSws::FFmpegSws(AVPixelFormat output, int width, int height) { _target_format = output; _target_width = width; _target_height = height; + + _sws_frame_pool.setSize(AV_NUM_DATA_POINTERS); } FFmpegSws::~FFmpegSws() { @@ -751,7 +810,7 @@ FFmpegFrame::Ptr FFmpegSws::inputFrame(const FFmpegFrame::Ptr &frame, int &ret, // Do not convert format return frame; } - if (_ctx && (_src_width != frame->get()->width || _src_height != frame->get()->height || _src_format != (enum AVPixelFormat) frame->get()->format)) { + if (_ctx && (_src_width != frame->get()->width || _src_height != frame->get()->height || _src_format != (enum AVPixelFormat)frame->get()->format)) { // 输入分辨率发生变化了 [AUTO-TRANSLATED:0e4ea2e8] // Input resolution has changed sws_freeContext(_ctx); @@ -765,7 +824,8 @@ FFmpegFrame::Ptr FFmpegSws::inputFrame(const FFmpegFrame::Ptr &frame, int &ret, InfoL << "sws_getContext:" << av_get_pix_fmt_name((enum AVPixelFormat) frame->get()->format) << " -> " << av_get_pix_fmt_name(_target_format); } if (_ctx) { - auto out = std::make_shared(); + auto out = _sws_frame_pool.obtain2(); + out->reset(); // 清理旧数据和帧引用 if (!out->get()->data[0]) { if (data) { av_image_fill_arrays(out->get()->data, out->get()->linesize, data, _target_format, target_width, target_height, 32); @@ -788,7 +848,129 @@ FFmpegFrame::Ptr FFmpegSws::inputFrame(const FFmpegFrame::Ptr &frame, int &ret, return nullptr; } -/////////////////////////////////////////////////////////////////////////////////////////////////////// +std::tuple FFmpegUtils::saveFrame(const FFmpegFrame::Ptr &frame, const char *filename, AVPixelFormat fmt, int w, int h, const char *font_path) { + std::shared_ptr _filter_graph; + AVFilterContext *buffersrc_ctx = nullptr; + AVFilterContext *buffersink_ctx = nullptr; + const AVFilter *buffersrc = nullptr; + const AVFilter *buffersink = nullptr; + // kServerName + const string mark = "ZLMediaKit"; + char drawtext_args1[512]; + _StrPrinter ss; + + std::unique_ptr tmp_save_file_jpg(File::create_file(filename, "wb"), [](FILE *fp) { + if (fp) { + fclose(fp); + } + }); + + if (!tmp_save_file_jpg) { + ss << "Could not open the file " << filename; + DebugL << ss; + return make_tuple(false, ss.data()); + } + + std::string fontfile(""); + if (font_path && File::fileExist(font_path)) { + fontfile = font_path; + } else { + // Fallback to common default + fontfile = exeDir() + "/DejaVuSans.ttf"; + } + + snprintf(drawtext_args1, sizeof(drawtext_args1), "text='%s':fontfile='%s':fontcolor=white@0.1:fontsize=h/50:x=w*0.02:y=h-th-h*0.02", mark.data(), fontfile.c_str()); + + const AVCodec *jpeg_codec = avcodec_find_encoder(fmt == AV_PIX_FMT_YUVJ420P ? AV_CODEC_ID_MJPEG : AV_CODEC_ID_PNG); + std::unique_ptr jpeg_codec_ctx( + jpeg_codec ? avcodec_alloc_context3(jpeg_codec) : nullptr, [](AVCodecContext *ctx) { avcodec_free_context(&ctx); }); + + if (!jpeg_codec_ctx) { + ss << "Could not allocate JPEG/PNG codec context"; + DebugL << ss; + return make_tuple(false, ss.data()); + } + + jpeg_codec_ctx->width = (w > 0 && w < 8192) ? w : frame->get()->width; + jpeg_codec_ctx->height = (h > 0 && h < 4320) ? h : frame->get()->height; + jpeg_codec_ctx->pix_fmt = fmt; + jpeg_codec_ctx->time_base = { 1, 1 }; + + auto ret = avcodec_open2(jpeg_codec_ctx.get(), jpeg_codec, NULL); + if (ret < 0) { + ss << "Could not open JPEG/PNG codec, " << ffmpeg_err(ret); + DebugL << ss; + return make_tuple(false, ss.data()); + } + + FFmpegSws sws(fmt, jpeg_codec_ctx->width, jpeg_codec_ctx->height); + auto new_frame = sws.inputFrame(frame); + if (!new_frame) { + ss << "Could not scale the frame"; + DebugL << ss; + return make_tuple(false, ss.data()); + } + + _filter_graph.reset(avfilter_graph_alloc(), [](AVFilterGraph *ctx) { avfilter_graph_free(&ctx); }); + if (!_filter_graph) { + ss << "avfilter_graph_alloc failed"; + DebugL << ss; + return make_tuple(false, ss.data()); + } + + char args[512]; + snprintf( + args, sizeof(args), "video_size=%dx%d:pix_fmt=%d:time_base=%d/%d:pixel_aspect=%d/%d", jpeg_codec_ctx->width, jpeg_codec_ctx->height, + jpeg_codec_ctx->pix_fmt, jpeg_codec_ctx->time_base.num, jpeg_codec_ctx->time_base.den, jpeg_codec_ctx->sample_aspect_ratio.num, + jpeg_codec_ctx->sample_aspect_ratio.den); + + buffersrc = avfilter_get_by_name("buffer"); + + if ((ret = avfilter_graph_create_filter(&buffersrc_ctx, buffersrc, "in", args, NULL, _filter_graph.get())) < 0) { + ss << "avfilter_graph_create_filter buffersrc failed: " << ret << " " << ffmpeg_err(ret); + DebugL << ss; + return make_tuple(false, ss.data()); + } + + buffersink = avfilter_get_by_name("buffersink"); + if ((ret = avfilter_graph_create_filter(&buffersink_ctx, buffersink, "out", NULL, NULL, _filter_graph.get())) < 0) { + ss << "avfilter_graph_create_filter buffersink failed: " << ret << " " << ffmpeg_err(ret); + return make_tuple(false, ss.data()); + } + + AVFilterContext *drawtext_ctx1 = nullptr; + + const AVFilter *drawtext_filter = avfilter_get_by_name("drawtext"); + if ((ret = avfilter_graph_create_filter(&drawtext_ctx1, drawtext_filter, "drawtext", drawtext_args1, NULL, _filter_graph.get())) < 0) { + ss << "avfilter_graph_create_filter drawtext_filter failed: " << ret << " " << ffmpeg_err(ret); + return make_tuple(false, ss.data()); + } + + if ((ret = avfilter_link(buffersrc_ctx, 0, drawtext_ctx1, 0) < 0 || avfilter_link(drawtext_ctx1, 0, buffersink_ctx, 0))< 0) { + ss << "avfilter_link: " << ret << " " << ffmpeg_err(ret); + return make_tuple(false, ss.data()); + } + + if ((ret = avfilter_graph_config(_filter_graph.get(), NULL)) < 0) { + ss << "avfilter_graph_config failed: " << ret << " " << ffmpeg_err(ret); + return make_tuple(false, ss.data()); + } + + if ((ret = av_buffersrc_add_frame_flags(buffersrc_ctx, new_frame->get(), 0)) < 0) { + ss << "av_buffersink_get_frame failed: " << ret << " " << ffmpeg_err(ret); + return make_tuple(false, ss.data()); + } + + auto pkt = alloc_av_packet(); + while (av_buffersink_get_frame(buffersink_ctx, new_frame->get()) >= 0) { + if (avcodec_send_frame(jpeg_codec_ctx.get(), new_frame->get()) == 0) { + while (avcodec_receive_packet(jpeg_codec_ctx.get(), pkt.get()) == 0) { + fwrite(pkt.get()->data, pkt.get()->size, 1, tmp_save_file_jpg.get()); + } + } + } + return make_tuple(true, ""); +} void setupContext(AVCodecContext *_context, int bitrate) { //保存AVFrame的引用 @@ -941,17 +1123,20 @@ bool FFmpegEncoder::openAudioCodec(int samplerate, int channel, int bitrate, con _context->sample_fmt = codec->sample_fmts[0]; _context->sample_rate = samplerate; +#if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + av_channel_layout_default(&_context->ch_layout, channel); + _swr.reset(new FFmpegSwr(_context->sample_fmt, &_context->ch_layout, _context->sample_rate)); +#else _context->channels = channel; _context->channel_layout = av_get_default_channel_layout(_context->channels); + _swr.reset(new FFmpegSwr(_context->sample_fmt, _context->channels, _context->channel_layout, _context->sample_rate)); +#endif if (getCodecId() == CodecOpus) _context->compression_level = 1; //_sample_bytes = av_get_bytes_per_sample(_context->sample_fmt) * _context->channels; - _swr.reset( - new FFmpegSwr(_context->sample_fmt, _context->channels, _context->channel_layout, _context->sample_rate)); - - InfoL << "openAudioCodec " << codec->name << " " << _context->sample_rate << "x" << _context->channels; + InfoL << "openAudioCodec " << codec->name << " " << _context->sample_rate << "x" << channel; return avcodec_open2(_context.get(), codec, &_dict) >= 0; } return false; diff --git a/src/Codec/Transcode.h b/src/Codec/Transcode.h index b55130a8..de857fa0 100644 --- a/src/Codec/Transcode.h +++ b/src/Codec/Transcode.h @@ -26,10 +26,16 @@ extern "C" { #include "libswresample/swresample.h" #include "libavutil/audio_fifo.h" #include "libavutil/imgutils.h" +#include "libavutil/frame.h" +#include "libavfilter/avfilter.h" +#include "libavfilter/buffersink.h" +#include "libavfilter/buffersrc.h" #ifdef __cplusplus } #endif +#define FF_CODEC_VER_7_1 AV_VERSION_INT(61, 0, 0) + namespace mediakit { class FFmpegFrame { @@ -41,9 +47,11 @@ public: AVFrame *get() const; void fillPicture(AVPixelFormat target_format, int target_width, int target_height); + int getChannels() const; + void reset(); private: - char *_data = nullptr; + std::unique_ptr _data; std::shared_ptr _frame; }; @@ -51,16 +59,29 @@ class FFmpegSwr { public: using Ptr = std::shared_ptr; +# if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + FFmpegSwr(AVSampleFormat output, AVChannelLayout *ch_layout, int samplerate); +#else FFmpegSwr(AVSampleFormat output, int channel, int channel_layout, int samplerate); +#endif + ~FFmpegSwr(); FFmpegFrame::Ptr inputFrame(const FFmpegFrame::Ptr &frame); private: + +# if LIBAVCODEC_VERSION_INT >= FF_CODEC_VER_7_1 + AVChannelLayout _target_ch_layout; +#else int _target_channels; int _target_channel_layout; +#endif + int _target_samplerate; AVSampleFormat _target_format; SwrContext *_ctx = nullptr; + + toolkit::ResourcePool _swr_frame_pool; }; class FFmpegAudioFifo { @@ -138,6 +159,7 @@ private: onDec _cb; std::shared_ptr _context; FrameMerger _merger{FrameMerger::h264_prefix}; + toolkit::ResourcePool _frame_pool; }; class FFmpegSws { @@ -160,6 +182,21 @@ private: SwsContext *_ctx = nullptr; AVPixelFormat _src_format = AV_PIX_FMT_NONE; AVPixelFormat _target_format = AV_PIX_FMT_NONE; + toolkit::ResourcePool _sws_frame_pool; +}; + +class FFmpegUtils { +public: + /** + * 保持图片为jpeg或png + * @param frame 解码后的帧 + * @param filename 保存文件路径 + * @param fmt jpg:AV_PIX_FMT_YUVJ420P,PNG:AV_PIX_FMT_RGB24 + * @param w h (可选)裁剪的图片大小,默认和输入源一致 + * @param font_path (可选), default DejaVuSans.ttf + * @return + */ + static std::tuple saveFrame(const FFmpegFrame::Ptr &frame, const char *filename, AVPixelFormat fmt = AV_PIX_FMT_YUVJ420P, int w = 0, int h = 0, const char *font_path = nullptr); }; class FFmpegEncoder : public TaskManager, public CodecInfo { diff --git a/src/Common/MediaSink.cpp b/src/Common/MediaSink.cpp index 52992b05..cec1841c 100644 --- a/src/Common/MediaSink.cpp +++ b/src/Common/MediaSink.cpp @@ -176,7 +176,9 @@ void MediaSink::checkTrackIfReady() { } void MediaSink::addTrackCompleted() { - setMaxTrackCount(_track_map.size()); + if (!_track_map.empty()) { + setMaxTrackCount(_track_map.size()); + } } void MediaSink::setMaxTrackCount(size_t i) { diff --git a/src/Common/MediaSource.cpp b/src/Common/MediaSource.cpp index 14c18d13..f1f210bd 100644 --- a/src/Common/MediaSource.cpp +++ b/src/Common/MediaSource.cpp @@ -110,13 +110,20 @@ std::shared_ptr MediaSource::getOwnership() { }); } -int MediaSource::getBytesSpeed(TrackType type){ +size_t MediaSource::getBytesSpeed(TrackType type) { if(type == TrackInvalid || type == TrackMax){ return _speed[TrackVideo].getSpeed() + _speed[TrackAudio].getSpeed(); } return _speed[type].getSpeed(); } +size_t MediaSource::getTotalBytes(TrackType type) { + if (type == TrackInvalid || type == TrackMax) { + return _speed[TrackVideo].getTotalBytes() + _speed[TrackAudio].getTotalBytes(); + } + return _speed[type].getTotalBytes(); +} + uint64_t MediaSource::getAliveSecond() const { // 使用Ticker对象获取存活时间的目的是防止修改系统时间导致回退 [AUTO-TRANSLATED:68474061] // The purpose of using the Ticker object to obtain the survival time is to prevent the modification of the system time from causing a rollback @@ -125,10 +132,10 @@ uint64_t MediaSource::getAliveSecond() const { vector MediaSource::getTracks(bool ready) const { auto listener = _listener.lock(); - if(!listener){ + if (!listener) { return vector(); } - return listener->getMediaTracks(const_cast(*this), ready); + return listener->getMuxer(const_cast(*this))->getTracks(ready); } void MediaSource::setListener(const std::weak_ptr &listener){ @@ -270,7 +277,7 @@ bool MediaSource::setupRecord(Recorder::type type, bool start, const string &cus WarnL << "未设置MediaSource的事件监听者,setupRecord失败:" << getUrl(); return false; } - return listener->setupRecord(*this, type, start, custom_path, max_second); + return listener->getMuxer(const_cast(*this))->setupRecord(type, start, custom_path, max_second); } bool MediaSource::isRecording(Recorder::type type){ @@ -278,7 +285,7 @@ bool MediaSource::isRecording(Recorder::type type){ if(!listener){ return false; } - return listener->isRecording(*this, type); + return listener->getMuxer(const_cast(*this))->isRecording(type); } void MediaSource::startSendRtp(const MediaSourceEvent::SendRtpArgs &args, const std::function cb) { @@ -287,7 +294,7 @@ void MediaSource::startSendRtp(const MediaSourceEvent::SendRtpArgs &args, const cb(0, SockException(Err_other, "尚未设置事件监听器")); return; } - return listener->startSendRtp(*this, args, cb); + return listener->getMuxer(const_cast(*this))->startSendRtp(args, cb); } bool MediaSource::stopSendRtp(const string &ssrc) { @@ -295,7 +302,7 @@ bool MediaSource::stopSendRtp(const string &ssrc) { if (!listener) { return false; } - return listener->stopSendRtp(*this, ssrc); + return listener->getMuxer(const_cast(*this))->stopSendRtp(ssrc); } template @@ -660,6 +667,14 @@ void MediaSourceEvent::onReaderChanged(MediaSource &sender, int size){ bool is_mp4_vod = sender.getMediaTuple().app == record_app; weak_ptr weak_sender = sender.shared_from_this(); + EventPoller::Ptr specified_poller; + try { + specified_poller = this->getOwnerPoller(sender); + } + catch (std::exception &ex) { + // 尝试获取 OwnerPoller,没有实现则使用默认 nullptr + // WarnL << ex.what(); + } _async_close_timer = std::make_shared(stream_none_reader_delay / 1000.0f, [weak_sender, is_mp4_vod]() { auto strong_sender = weak_sender.lock(); if (!strong_sender) { @@ -675,25 +690,24 @@ void MediaSourceEvent::onReaderChanged(MediaSource &sender, int size){ } if (!is_mp4_vod) { + // 直播时触发无人观看事件,让开发者自行选择是否关闭 [AUTO-TRANSLATED:c6c75eaa] + // When live streaming, trigger the no-viewer event, allowing developers to choose whether to close it. + NOTICE_EMIT(BroadcastStreamNoneReaderArgs, Broadcast::kBroadcastStreamNoneReader, *strong_sender); auto muxer = strong_sender->getMuxer(); if (muxer && muxer->getOption().auto_close) { // 此流被标记为无人观看自动关闭流 [AUTO-TRANSLATED:64a0dac3] // This stream is marked as an automatically closed stream with no viewers. - WarnL << "Auto cloe stream when none reader: " << strong_sender->getUrl(); - strong_sender->close(false); - } else { - // 直播时触发无人观看事件,让开发者自行选择是否关闭 [AUTO-TRANSLATED:c6c75eaa] - // When live streaming, trigger the no-viewer event, allowing developers to choose whether to close it. - NOTICE_EMIT(BroadcastStreamNoneReaderArgs, Broadcast::kBroadcastStreamNoneReader, *strong_sender); + WarnL << "Auto close stream when none reader: " << strong_sender->getUrl(); + strong_sender->getOwnerPoller()->async([strong_sender]() { strong_sender->close(false); }); } } else { // 这个是mp4点播,我们自动关闭 [AUTO-TRANSLATED:8a7b9a90] // This is an mp4 on-demand, we automatically close it. WarnL << "MP4点播无人观看,自动关闭:" << strong_sender->getUrl(); - strong_sender->close(false); + strong_sender->getOwnerPoller()->async([strong_sender]() { strong_sender->close(false); }); } return false; - }, nullptr); + }, specified_poller); } string MediaSourceEvent::getOriginUrl(MediaSource &sender) const { @@ -816,46 +830,6 @@ std::shared_ptr MediaSourceEventInterceptor::getRtpProcess(MediaSour return listener->getRtpProcess(sender); } -bool MediaSourceEventInterceptor::setupRecord(MediaSource &sender, Recorder::type type, bool start, const string &custom_path, size_t max_second) { - auto listener = _listener.lock(); - if (!listener) { - return MediaSourceEvent::setupRecord(sender, type, start, custom_path, max_second); - } - return listener->setupRecord(sender, type, start, custom_path, max_second); -} - -bool MediaSourceEventInterceptor::isRecording(MediaSource &sender, Recorder::type type) { - auto listener = _listener.lock(); - if (!listener) { - return MediaSourceEvent::isRecording(sender, type); - } - return listener->isRecording(sender, type); -} - -vector MediaSourceEventInterceptor::getMediaTracks(MediaSource &sender, bool trackReady) const { - auto listener = _listener.lock(); - if (!listener) { - return MediaSourceEvent::getMediaTracks(sender, trackReady); - } - return listener->getMediaTracks(sender, trackReady); -} - -void MediaSourceEventInterceptor::startSendRtp(MediaSource &sender, const MediaSourceEvent::SendRtpArgs &args, const std::function cb) { - auto listener = _listener.lock(); - if (!listener) { - return MediaSourceEvent::startSendRtp(sender, args, cb); - } - listener->startSendRtp(sender, args, cb); -} - -bool MediaSourceEventInterceptor::stopSendRtp(MediaSource &sender, const string &ssrc) { - auto listener = _listener.lock(); - if (!listener) { - return MediaSourceEvent::stopSendRtp(sender, ssrc); - } - return listener->stopSendRtp(sender, ssrc); -} - void MediaSourceEventInterceptor::setDelegate(const std::weak_ptr &listener) { if (listener.lock().get() == this) { throw std::invalid_argument("can not set self as a delegate"); diff --git a/src/Common/MediaSource.h b/src/Common/MediaSource.h index c771a11a..16675b5a 100644 --- a/src/Common/MediaSource.h +++ b/src/Common/MediaSource.h @@ -94,17 +94,6 @@ public: // Get the current thread, this function is generally forced to overload virtual toolkit::EventPoller::Ptr getOwnerPoller(MediaSource &sender) { throw NotImplemented(toolkit::demangle(typeid(*this).name()) + "::getOwnerPoller not implemented"); } - // //////////////////////仅供MultiMediaSourceMuxer对象继承//////////////////////// [AUTO-TRANSLATED:6e810d1f] - // //////////////////////Only for MultiMediaSourceMuxer object inheritance//////////////////////// - // 开启或关闭录制 [AUTO-TRANSLATED:3817e390] - // Start or stop recording - virtual bool setupRecord(MediaSource &sender, Recorder::type type, bool start, const std::string &custom_path, size_t max_second) { return false; }; - // 获取录制状态 [AUTO-TRANSLATED:a0499880] - // Get recording status - virtual bool isRecording(MediaSource &sender, Recorder::type type) { return false; } - // 获取所有track相关信息 [AUTO-TRANSLATED:2141be42] - // Get all track related information - virtual std::vector getMediaTracks(MediaSource &sender, bool trackReady = true) const { return std::vector(); }; // 获取MultiMediaSourceMuxer对象 [AUTO-TRANSLATED:2de96d44] // Get MultiMediaSourceMuxer object virtual std::shared_ptr getMuxer(MediaSource &sender) const { return nullptr; } @@ -125,7 +114,7 @@ public: kUdpActive = 1, // udp主动模式,主动发送数据给对方 kTcpPassive = 2, // tcp被动模式,tcp服务器,等待对方连接并回复rtp kUdpPassive = 3, // udp被动方式,等待对方发送nat打洞包,然后回复rtp至打洞包源地址 - kVoiceTalk = 4, // 语音对讲模式,对方必须想推流上来,通过他的推流链路再回复rtp数据 + kVoiceTalk = 4, // 语音对讲模式,对方必须先推流上来,通过他的推流链路再回复rtp数据 }; // rtp类型 [AUTO-TRANSLATED:acca40ab] @@ -175,14 +164,10 @@ public: std::string recv_stream_app; std::string recv_stream_vhost; - }; - // 开始发送ps-rtp [AUTO-TRANSLATED:a51796fa] - // Start sending ps-rtp - virtual void startSendRtp(MediaSource &sender, const SendRtpArgs &args, const std::function cb) { cb(0, toolkit::SockException(toolkit::Err_other, "not implemented"));}; - // 停止发送ps-rtp [AUTO-TRANSLATED:952d2b35] - // Stop sending ps-rtp - virtual bool stopSendRtp(MediaSource &sender, const std::string &ssrc) {return false; } + // rtp tcp模式发送时busy时, origin 接收限流, 默认不启用 + bool enable_origin_recv_limit = false; + }; private: toolkit::Timer::Ptr _async_close_timer; @@ -306,6 +291,38 @@ public: // Maximum number of tracks size_t max_track = 2; +#define OPT_VALUE(XX) \ + XX(modify_stamp) \ + XX(enable_audio) \ + XX(add_mute_audio) \ + XX(auto_close) \ + XX(continue_push_ms) \ + XX(paced_sender_ms) \ + \ + XX(enable_hls) \ + XX(enable_hls_fmp4) \ + XX(enable_mp4) \ + XX(enable_rtsp) \ + XX(enable_rtmp) \ + XX(enable_ts) \ + XX(enable_fmp4) \ + XX(enable_rtc) \ + XX(audio_transcode) \ + XX(rtc_demand) \ + XX(hls_demand) \ + XX(rtsp_demand) \ + XX(rtmp_demand) \ + XX(ts_demand) \ + XX(fmp4_demand) \ + \ + XX(mp4_max_second) \ + XX(mp4_as_player) \ + XX(mp4_save_path) \ + \ + XX(hls_save_path) \ + XX(stream_replace) \ + XX(max_track) + template ProtocolOption(const MAP &allArgs) : ProtocolOption() { load(allArgs); @@ -313,38 +330,18 @@ public: template void load(const MAP &allArgs) { -#define GET_OPT_VALUE(key) getArgsValue(allArgs, #key, key) - GET_OPT_VALUE(modify_stamp); - GET_OPT_VALUE(enable_audio); - GET_OPT_VALUE(add_mute_audio); - GET_OPT_VALUE(auto_close); - GET_OPT_VALUE(continue_push_ms); - GET_OPT_VALUE(paced_sender_ms); +#define GET(key) getArgsValue(allArgs, #key, key); + OPT_VALUE(GET) +#undef GET + } - GET_OPT_VALUE(enable_hls); - GET_OPT_VALUE(enable_hls_fmp4); - GET_OPT_VALUE(enable_mp4); - GET_OPT_VALUE(enable_rtsp); - GET_OPT_VALUE(enable_rtmp); - GET_OPT_VALUE(enable_ts); - GET_OPT_VALUE(enable_fmp4); - GET_OPT_VALUE(enable_rtc); - GET_OPT_VALUE(audio_transcode); - GET_OPT_VALUE(rtc_demand); - - GET_OPT_VALUE(hls_demand); - GET_OPT_VALUE(rtsp_demand); - GET_OPT_VALUE(rtmp_demand); - GET_OPT_VALUE(ts_demand); - GET_OPT_VALUE(fmp4_demand); - - GET_OPT_VALUE(mp4_max_second); - GET_OPT_VALUE(mp4_as_player); - GET_OPT_VALUE(mp4_save_path); - - GET_OPT_VALUE(hls_save_path); - GET_OPT_VALUE(stream_replace); - GET_OPT_VALUE(max_track); + template + MAP as() { + MAP ret; +#define SET(key) ret[#key] = key; + OPT_VALUE(SET) +#undef SET + return ret; } }; @@ -366,11 +363,6 @@ public: int totalReaderCount(MediaSource &sender) override; void onReaderChanged(MediaSource &sender, int size) override; void onRegist(MediaSource &sender, bool regist) override; - bool setupRecord(MediaSource &sender, Recorder::type type, bool start, const std::string &custom_path, size_t max_second) override; - bool isRecording(MediaSource &sender, Recorder::type type) override; - std::vector getMediaTracks(MediaSource &sender, bool trackReady = true) const override; - void startSendRtp(MediaSource &sender, const SendRtpArgs &args, const std::function cb) override; - bool stopSendRtp(MediaSource &sender, const std::string &ssrc) override; float getLossRate(MediaSource &sender, TrackType type) override; toolkit::EventPoller::Ptr getOwnerPoller(MediaSource &sender) override; std::shared_ptr getMuxer(MediaSource &sender) const override; @@ -395,6 +387,7 @@ public: public: uint16_t port = 0; + std::string protocol; std::string full_url; std::string schema; std::string host; @@ -448,7 +441,9 @@ public: // 获取数据速率,单位bytes/s [AUTO-TRANSLATED:c70465c1] // Get data rate, unit bytes/s - int getBytesSpeed(TrackType type = TrackInvalid); + size_t getBytesSpeed(TrackType type = TrackInvalid); + size_t getTotalBytes(TrackType type = TrackInvalid); + // 获取流创建GMT unix时间戳,单位秒 [AUTO-TRANSLATED:0bbe145e] // Get the stream creation GMT unix timestamp, unit seconds uint64_t getCreateStamp() const { return _create_stamp; } diff --git a/src/Common/MultiMediaSourceMuxer.cpp b/src/Common/MultiMediaSourceMuxer.cpp index a15c5b89..97430081 100644 --- a/src/Common/MultiMediaSourceMuxer.cpp +++ b/src/Common/MultiMediaSourceMuxer.cpp @@ -11,6 +11,7 @@ #include #include "Common/config.h" #include "MultiMediaSourceMuxer.h" +#include "Thread/WorkThreadPool.h" #include "Rtp/RtpSender.h" #include "Record/HlsRecorder.h" #include "Record/HlsMediaSource.h" @@ -71,32 +72,38 @@ public: setCurrentStamp(frame->dts()); resetTimer(EventPoller::getCurrentPoller()); } - - _cache.emplace_back(frame->dts() + _cache_ms, Frame::getCacheAbleFrame(frame)); + auto &last_dts = _last_dts[frame->getTrackType()]; + if (last_dts > frame->dts()) { + // 时间戳回退了,点播流? + WarnL << "Dts decrease: " << last_dts << "->" << frame->dts() << ", flush all paced sender cache: " << _cache.size(); + flushCache(frame->dts()); + } + _cache.emplace(frame->dts(), Frame::getCacheAbleFrame(frame)); + last_dts = frame->dts(); return true; } private: void onTick() { std::lock_guard lck(_mtx); - auto dst = _cache.empty() ? 0 : _cache.back().first; + auto max_dts = _cache.empty() ? 0 : _cache.rbegin()->first; while (!_cache.empty()) { - auto &front = _cache.front(); - if (getCurrentStamp() < front.first) { + auto front = _cache.begin(); + if (getCurrentStamp() < front->first + _cache_ms) { // 还没到消费时间 [AUTO-TRANSLATED:09fb4c3d] // Not yet time to consume break; } // 时间到了,该消费frame了 [AUTO-TRANSLATED:2f007931] // Time is up, it's time to consume the frame - _cb(front.second); - _cache.pop_front(); + _cb(front->second); + _cache.erase(front); } - if (_cache.empty() && dst) { + if (_cache.empty() && max_dts) { // 消费太快,需要增加缓存大小 [AUTO-TRANSLATED:c05bfbcd] // Consumption is too fast, need to increase cache size - setCurrentStamp(dst); + setCurrentStamp(max_dts); _cache_ms += kMinCacheMS; } @@ -104,15 +111,20 @@ private: // Consumption is too slow, need to force flush data if (_cache.size() > 25 * 5) { WarnL << "Flush frame paced sender cache: " << _cache.size(); - while (!_cache.empty()) { - auto &front = _cache.front(); - _cb(front.second); - _cache.pop_front(); - } - setCurrentStamp(dst); + flushCache(max_dts); } } + void flushCache(uint64_t dts) { + while (!_cache.empty()) { + auto front = _cache.begin(); + _cb(front->second); + _cache.erase(front); + } + setCurrentStamp(dts); + _cache_ms = kMinCacheMS; + } + uint64_t getCurrentStamp() { return _ticker.elapsedTime() + _stamp_offset; } void setCurrentStamp(uint64_t stamp) { @@ -124,15 +136,16 @@ private: uint32_t _paced_sender_ms; uint32_t _cache_ms = kMinCacheMS; uint64_t _stamp_offset = 0; + uint64_t _last_dts[2] = {0, 0}; OnFrame _cb; Ticker _ticker; Timer::Ptr _timer; std::recursive_mutex _mtx; - std::list> _cache; + std::multimap _cache; }; -std::shared_ptr MultiMediaSourceMuxer::makeRecorder(MediaSource &sender, Recorder::type type) { - auto recorder = Recorder::createRecorder(type, sender.getMediaTuple(), _option); +std::shared_ptr MultiMediaSourceMuxer::makeRecorder(Recorder::type type) { + auto recorder = Recorder::createRecorder(type, getMediaTuple(), _option); for (auto &track : getTracks()) { recorder->addTrack(track); } @@ -173,7 +186,7 @@ static string getTrackInfoStr(const TrackSource *track_src){ break; } } - return std::move(codec_info); + return codec_info; } const ProtocolOption &MultiMediaSourceMuxer::getOption() const { @@ -191,13 +204,16 @@ std::string MultiMediaSourceMuxer::shortUrl() const { } return _tuple.shortUrl(); } - -void MultiMediaSourceMuxer::forEachRtpSender(const std::function &cb) const { +#if defined(ENABLE_RTPPROXY) +void MultiMediaSourceMuxer::forEachRtpSender(const std::function &cb) const { for (auto &pr : _rtp_sender) { - cb(pr.first); + auto sender = std::get<1>(pr.second).lock(); + if (sender) { + cb(pr.first, *sender); + } } } - +#endif // ENABLE_RTPPROXY MultiMediaSourceMuxer::MultiMediaSourceMuxer(const MediaTuple& tuple, float dur_sec, const ProtocolOption &option): _tuple(tuple) { if (!option.stream_replace.empty()) { // 支持在on_publish hook中替换stream_id [AUTO-TRANSLATED:375eb2ff] @@ -242,6 +258,8 @@ MultiMediaSourceMuxer::MultiMediaSourceMuxer(const MediaTuple& tuple, float dur_ // Audio related settings enableAudio(option.enable_audio); enableMuteAudio(option.add_mute_audio); + + NOTICE_EMIT(BroadcastCreateMuxerArgs, Broadcast::kBroadcastCreateMuxer, _delegate, *this); } void MultiMediaSourceMuxer::setMediaListener(const std::weak_ptr &listener) { @@ -315,13 +333,13 @@ int MultiMediaSourceMuxer::totalReaderCount(MediaSource &sender) { // 此函数可能跨线程调用 [AUTO-TRANSLATED:e8c5f74d] // This function may be called across threads -bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type, bool start, const string &custom_path, size_t max_second) { +bool MultiMediaSourceMuxer::setupRecord(Recorder::type type, bool start, const string &custom_path, size_t max_second) { CHECK(getOwnerPoller(MediaSource::NullMediaSource())->isCurrentThread(), "Can only call setupRecord in it's owner poller"); onceToken token(nullptr, [&]() { if (_option.mp4_as_player && type == Recorder::type_mp4) { // 开启关闭mp4录制,触发观看人数变化相关事件 [AUTO-TRANSLATED:b63a8deb] // Turn on/off mp4 recording, trigger events related to changes in the number of viewers - onReaderChanged(sender, totalReaderCount()); + onReaderChanged(MediaSource::NullMediaSource(), totalReaderCount()); } }); switch (type) { @@ -330,7 +348,7 @@ bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type // 开始录制 [AUTO-TRANSLATED:36d99250] // Start recording _option.hls_save_path = custom_path; - auto hls = dynamic_pointer_cast(makeRecorder(sender, type)); + auto hls = dynamic_pointer_cast(makeRecorder(type)); if (hls) { // 设置HlsMediaSource的事件监听器 [AUTO-TRANSLATED:69990c92] // Set the event listener for HlsMediaSource @@ -350,7 +368,7 @@ bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type // Start recording _option.mp4_save_path = custom_path; _option.mp4_max_second = max_second; - _mp4 = makeRecorder(sender, type); + _mp4 = makeRecorder(type); } else if (!start && _mp4) { // 停止录制 [AUTO-TRANSLATED:3dee9292] // Stop recording @@ -363,7 +381,7 @@ bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type // 开始录制 [AUTO-TRANSLATED:36d99250] // Start recording _option.hls_save_path = custom_path; - auto hls = dynamic_pointer_cast(makeRecorder(sender, type)); + auto hls = dynamic_pointer_cast(makeRecorder(type)); if (hls) { // 设置HlsMediaSource的事件监听器 [AUTO-TRANSLATED:69990c92] // Set the event listener for HlsMediaSource @@ -379,7 +397,7 @@ bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type } case Recorder::type_fmp4: { if (start && !_fmp4) { - auto fmp4 = dynamic_pointer_cast(makeRecorder(sender, type)); + auto fmp4 = dynamic_pointer_cast(makeRecorder(type)); if (fmp4) { fmp4->setListener(shared_from_this()); } @@ -391,7 +409,7 @@ bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type } case Recorder::type_ts: { if (start && !_ts) { - auto ts = dynamic_pointer_cast(makeRecorder(sender, type)); + auto ts = dynamic_pointer_cast(makeRecorder(type)); if (ts) { ts->setListener(shared_from_this()); } @@ -405,9 +423,155 @@ bool MultiMediaSourceMuxer::setupRecord(MediaSource &sender, Recorder::type type } } +std::string MultiMediaSourceMuxer::startRecord(const std::string &file_path, int back_time_ms, int forward_time_ms) { +#if !defined(ENABLE_MP4) + throw std::invalid_argument("mp4相关功能未打开,请开启ENABLE_MP4宏后编译再测试"); +#else + if (!_ring) { + throw std::runtime_error("frame gop cache disabled, start record event video failed"); + } + std::string path; + if (!start_with(file_path, "/")) { + path = Recorder::getRecordPath(Recorder::type_mp4, _tuple, _option.mp4_save_path); + path += file_path; + } else { + path = file_path; + } + TraceL << "mp4 save path: " << path; + + auto muxer = std::make_shared(); + muxer->openMP4(path); + for (auto &track : MediaSink::getTracks()) { + muxer->addTrack(track); + } + muxer->addTrackCompleted(); + + bool have_history = false; + if (back_time_ms > 0) { + // 回溯录制 + std::list history; + _ring->flushGop([&](const Frame::Ptr &frame) { history.emplace_back(frame); }); + if (!history.empty()) { + auto now_dts = history.back()->dts(); + + decltype(history)::iterator pos = history.end(); + for (auto it = history.rbegin(); it != history.rend(); ++it) { + auto &frame = *it; + if (frame->getTrackType() != TrackVideo || (!frame->configFrame() && !frame->keyFrame())) { + continue; + } + // 如果视频关键帧到末尾的时长超过一定的时间,那前面的数据应该全部删除 + if (frame->dts() + back_time_ms < now_dts) { + pos = it.base(); + --pos; + break; + } + } + if (pos != history.end()) { + // 移除历史视频前面过多的数据 + DebugL << "clear history front video: " << history.front()->dts() << " -> " << (*pos)->dts(); + history.erase(history.begin(), pos); + } + + if (forward_time_ms < 0) { + // 如果后向录制时长为负,说明回溯录制要截取一段尾部 + pos = history.end(); + for (auto it = history.rbegin(); it != history.rend(); ++it) { + auto &frame = *it; + if (frame->getTrackType() != TrackVideo) { + continue; + } + if (frame->dts() < now_dts + forward_time_ms) { + pos = it.base(); + ++pos; + break; + } + } + + if (pos != history.end()) { + // 移除历史视频后面过多的数据 + DebugL << "clear history tail video: " << (*pos)->dts() << " -> " << now_dts; + history.erase(pos, history.end()); + } + } + + if (!history.empty()) { + auto &front = history.front(); + InfoL << "start record: " << path + << ", start_dts: " << front->dts() << ", key_frame: " << front->keyFrame() << ", config_frame: " << front->configFrame() + << ", now_dts: " << now_dts; + have_history = true; + } + + for (auto &frame : history) { + muxer->inputFrame(frame); + } + } + } + + if (forward_time_ms > 0) { + if (!have_history) { + InfoL << "start record: " << path << ", back_time_ms: " << back_time_ms << ", forward_time_ms: " << forward_time_ms; + } + + weak_ptr weak_self = shared_from_this(); + auto lam = [weak_self, muxer, forward_time_ms, have_history, path]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + uint64_t now_dts = 0; + int selected_index = -1; + Ticker ticker; + bool is_live_stream = strong_self->_dur_sec < 0.01; + auto reader = strong_self->_ring->attach(strong_self->MultiMediaSourceMuxer::getOwnerPoller(MediaSource::NullMediaSource()), !have_history, 1); + reader->setReadCB([muxer, now_dts, selected_index, forward_time_ms, reader, path, ticker, is_live_stream](const Frame::Ptr &frame) mutable { + if (!reader) { + // 已经关闭录制 + return; + } + // 循环引用自身 + if (!now_dts) { + now_dts = frame->dts(); + selected_index = frame->getIndex(); + } + // 新增兜底机制,如果直播录制任务时长超过预期时间3秒,不管数据时间戳是否增长是否达到预期,都强制停止录制 + if ((frame->getIndex() == selected_index && now_dts + forward_time_ms < frame->dts()) + || (is_live_stream && ticker.createdTime() > forward_time_ms + 3000ULL)) { + InfoL << "stop record: " << path << ", end dts: " << frame->dts(); + WorkThreadPool::Instance().getPoller()->async([muxer]() { muxer->closeMP4(); }); + reader = nullptr; + return; + } + muxer->inputFrame(frame); + }); + std::weak_ptr weak_reader = reader; + reader->setDetachCB([weak_reader]() { + if (auto strong_reader = weak_reader.lock()) { + // 防止循环引用 + strong_reader->setReadCB(nullptr); + } + }); + }; + if (back_time_ms >= 0) { + // 立即前向录制 + lam(); + } else { + // 延时启动录制 + MultiMediaSourceMuxer::getOwnerPoller(MediaSource::NullMediaSource())->doDelayTask(-back_time_ms, [lam]() { + lam(); + return 0; + }); + } + } + + return path; +#endif +} + // 此函数可能跨线程调用 [AUTO-TRANSLATED:e8c5f74d] // This function may be called across threads -bool MultiMediaSourceMuxer::isRecording(MediaSource &sender, Recorder::type type) { +bool MultiMediaSourceMuxer::isRecording(Recorder::type type) { switch (type) { case Recorder::type_hls: return !!_hls; case Recorder::type_mp4: return !!_mp4; @@ -418,15 +582,15 @@ bool MultiMediaSourceMuxer::isRecording(MediaSource &sender, Recorder::type type } } -void MultiMediaSourceMuxer::startSendRtp(MediaSource &sender, const MediaSourceEvent::SendRtpArgs &args, const std::function cb) { +void MultiMediaSourceMuxer::startSendRtp(const MediaSourceEvent::SendRtpArgs &args, const std::function cb) { #if defined(ENABLE_RTPPROXY) - createGopCacheIfNeed(); + createGopCacheIfNeed(1); auto ring = _ring; auto ssrc = args.ssrc; auto ssrc_multi_send = args.ssrc_multi_send; auto tracks = getTracks(false); - auto poller = getOwnerPoller(sender); + auto poller = getOwnerPoller(MediaSource::NullMediaSource()); auto rtp_sender = std::make_shared(poller); weak_ptr weak_self = shared_from_this(); @@ -443,7 +607,7 @@ void MultiMediaSourceMuxer::startSendRtp(MediaSource &sender, const MediaSourceE } }); - rtp_sender->startSend(args, [ssrc,ssrc_multi_send, weak_self, rtp_sender, cb, tracks, ring, poller](uint16_t local_port, const SockException &ex) mutable { + rtp_sender->startSend(*this, args, [ssrc,ssrc_multi_send, weak_self, rtp_sender, cb, tracks, ring, poller](uint16_t local_port, const SockException &ex) mutable { cb(local_port, ex); auto strong_self = weak_self.lock(); if (!strong_self || ex) { @@ -463,10 +627,11 @@ void MultiMediaSourceMuxer::startSendRtp(MediaSource &sender, const MediaSourceE // 可能归属线程发生变更 [AUTO-TRANSLATED:2b379e30] // The owning thread may change strong_self->getOwnerPoller(MediaSource::NullMediaSource())->async([=]() { - if(!ssrc_multi_send) { + if (!ssrc_multi_send) { strong_self->_rtp_sender.erase(ssrc); } - strong_self->_rtp_sender.emplace(ssrc,reader); + std::weak_ptr sender = rtp_sender; + strong_self->_rtp_sender.emplace(ssrc, make_tuple(reader, sender)); }); }); #else @@ -474,7 +639,7 @@ void MultiMediaSourceMuxer::startSendRtp(MediaSource &sender, const MediaSourceE #endif//ENABLE_RTPPROXY } -bool MultiMediaSourceMuxer::stopSendRtp(MediaSource &sender, const string &ssrc) { +bool MultiMediaSourceMuxer::stopSendRtp(const string &ssrc) { #if defined(ENABLE_RTPPROXY) if (ssrc.empty()) { // 关闭全部 [AUTO-TRANSLATED:ffaadfda] @@ -491,10 +656,6 @@ bool MultiMediaSourceMuxer::stopSendRtp(MediaSource &sender, const string &ssrc) #endif//ENABLE_RTPPROXY } -vector MultiMediaSourceMuxer::getMediaTracks(MediaSource &sender, bool trackReady) const { - return getTracks(trackReady); -} - EventPoller::Ptr MultiMediaSourceMuxer::getOwnerPoller(MediaSource &sender) { auto listener = getDelegate(); if (!listener) { @@ -517,6 +678,21 @@ EventPoller::Ptr MultiMediaSourceMuxer::getOwnerPoller(MediaSource &sender) { } } +bool MultiMediaSourceMuxer::close(MediaSource &sender) { + MediaSourceEventInterceptor::close(sender); + _rtmp = nullptr; + _rtsp = nullptr; + _fmp4 = nullptr; + _ts = nullptr; + _mp4 = nullptr; + _hls = nullptr; + _hls_fmp4 = nullptr; +#if defined(ENABLE_RTPPROXY) + _rtp_sender.clear(); +#endif // ENABLE_RTPPROXY + return true; +} + std::shared_ptr MultiMediaSourceMuxer::getMuxer(MediaSource &sender) const { return const_cast(this)->shared_from_this(); } @@ -550,6 +726,9 @@ bool MultiMediaSourceMuxer::onTrackReady(const Track::Ptr &track) { if (_mp4) { ret = _mp4->addTrack(track) ? true : ret; } + if (_delegate) { + _delegate->addTrack(track); + } return ret; } @@ -598,9 +777,9 @@ void MultiMediaSourceMuxer::onAllTrackReady() { } #if defined(ENABLE_RTPPROXY) - GET_CONFIG(bool, gop_cache, RtpProxy::kGopCache); - if (gop_cache) { - createGopCacheIfNeed(); + GET_CONFIG(size_t, gop_cache, RtpProxy::kGopCache); + if (gop_cache > 0) { + createGopCacheIfNeed(gop_cache); } #endif @@ -612,10 +791,13 @@ void MultiMediaSourceMuxer::onAllTrackReady() { pr.second.syncTo(*first); } } + if (_delegate) { + _delegate->addTrackCompleted(); + } InfoL << "stream: " << shortUrl() << " , codec info: " << getTrackInfoStr(this); } -void MultiMediaSourceMuxer::createGopCacheIfNeed() { +void MultiMediaSourceMuxer::createGopCacheIfNeed(size_t gop_count) { if (_ring) { return; } @@ -629,7 +811,7 @@ void MultiMediaSourceMuxer::createGopCacheIfNeed() { strong_self->onReaderChanged(*src, strong_self->totalReaderCount()); }); } - }); + }, gop_count); } void MultiMediaSourceMuxer::resetTracks() { @@ -703,6 +885,9 @@ bool MultiMediaSourceMuxer::onTrackFrame_l(const Frame::Ptr &frame_in) { if (_fmp4) { ret = _fmp4->inputFrame(frame) ? true : ret; } + if (_delegate) { + _delegate->inputFrame(frame); + } if (_ring) { // 此场景由于直接转发,可能存在切换线程引起的数据被缓存在管道,所以需要CacheAbleFrame [AUTO-TRANSLATED:528afbb7] // In this scenario, due to direct forwarding, there may be data cached in the pipeline due to thread switching, so CacheAbleFrame is needed diff --git a/src/Common/MultiMediaSourceMuxer.h b/src/Common/MultiMediaSourceMuxer.h index 660b95a9..48d3e4f5 100644 --- a/src/Common/MultiMediaSourceMuxer.h +++ b/src/Common/MultiMediaSourceMuxer.h @@ -25,10 +25,11 @@ class TSMediaSourceMuxer; class FMP4MediaSourceMuxer; class RtpSender; -class MultiMediaSourceMuxer : public MediaSourceEventInterceptor, public MediaSink, public std::enable_shared_from_this{ +class MultiMediaSourceMuxer : public MediaSourceEventInterceptor, public MediaSink, public toolkit::noncopyable, public std::enable_shared_from_this{ public: using Ptr = std::shared_ptr; using RingType = toolkit::RingBuffer; + using onCreateMuxer = std::function; class Listener { public: @@ -120,7 +121,16 @@ public: * [AUTO-TRANSLATED:cb1fd8a9] */ - bool setupRecord(MediaSource &sender, Recorder::type type, bool start, const std::string &custom_path, size_t max_second) override; + bool setupRecord(Recorder::type type, bool start, const std::string &custom_path, size_t max_second); + + /** + * 开始录制mp4 + * @param file_path mp4相对路径 + * @param back_time_ms 回溯录制时长 + * @param forward_time_ms 后续录制时长 + * @return 录制文件绝对路径 + */ + std::string startRecord(const std::string &file_path, int back_time_ms, int forward_time_ms); /** * 获取录制状态 @@ -132,25 +142,13 @@ public: * [AUTO-TRANSLATED:798afa71] */ - bool isRecording(MediaSource &sender, Recorder::type type) override; + bool isRecording(Recorder::type type); /** * 开始发送ps-rtp流 - * @param dst_url 目标ip或域名 - * @param dst_port 目标端口 - * @param ssrc rtp的ssrc - * @param is_udp 是否为udp * @param cb 启动成功或失败回调 - * Start sending ps-rtp stream - * @param dst_url Target ip or domain name - * @param dst_port Target port - * @param ssrc rtp's ssrc - * @param is_udp Whether it is udp - * @param cb Start success or failure callback - - * [AUTO-TRANSLATED:620416c2] */ - void startSendRtp(MediaSource &sender, const MediaSourceEvent::SendRtpArgs &args, const std::function cb) override; + void startSendRtp(const MediaSourceEvent::SendRtpArgs &args, const std::function cb); /** * 停止ps-rtp发送 @@ -160,19 +158,7 @@ public: * [AUTO-TRANSLATED:b91e2055] */ - bool stopSendRtp(MediaSource &sender, const std::string &ssrc) override; - - /** - * 获取所有Track - * @param trackReady 是否筛选过滤未就绪的track - * @return 所有Track - * Get all Tracks - * @param trackReady Whether to filter out unready tracks - * @return All Tracks - - * [AUTO-TRANSLATED:53755f5d] - */ - std::vector getMediaTracks(MediaSource &sender, bool trackReady = true) const override; + bool stopSendRtp(const std::string &ssrc); /** * 获取所属线程 @@ -181,6 +167,12 @@ public: * [AUTO-TRANSLATED:a4dc847e] */ toolkit::EventPoller::Ptr getOwnerPoller(MediaSource &sender) override; + + /** + * 关闭流 + * @return 是否成功 + */ + bool close(MediaSource &sender) override; /** * 获取本对象 @@ -193,9 +185,9 @@ public: const ProtocolOption &getOption() const; const MediaTuple &getMediaTuple() const; std::string shortUrl() const; - - void forEachRtpSender(const std::function &cb) const; - +#if defined(ENABLE_RTPPROXY) + void forEachRtpSender(const std::function &cb) const; +#endif // ENABLE_RTPPROXY protected: /////////////////////////////////MediaSink override///////////////////////////////// @@ -231,8 +223,8 @@ protected: bool onTrackFrame_l(const Frame::Ptr &frame); private: - void createGopCacheIfNeed(); - std::shared_ptr makeRecorder(MediaSource &sender, Recorder::type type); + void createGopCacheIfNeed(size_t gop_count); + std::shared_ptr makeRecorder(Recorder::type type); private: bool _is_enable = false; @@ -245,7 +237,9 @@ private: toolkit::Ticker _last_check; std::unordered_map _stamps; std::weak_ptr _track_listener; - std::unordered_multimap _rtp_sender; +#if defined(ENABLE_RTPPROXY) + std::unordered_multimap>> _rtp_sender; +#endif // ENABLE_RTPPROXY std::shared_ptr _fmp4; std::shared_ptr _rtmp; std::shared_ptr _rtsp; @@ -257,6 +251,8 @@ private: toolkit::EventPoller::Ptr _poller; RingType::Ptr _ring; + MediaSinkInterface::Ptr _delegate; + // 对象个数统计 [AUTO-TRANSLATED:3b43e8c2] // Object count statistics toolkit::ObjectStatistic _statistic; diff --git a/src/Common/Stamp.cpp b/src/Common/Stamp.cpp index 2cc23dfc..eb1e335d 100644 --- a/src/Common/Stamp.cpp +++ b/src/Common/Stamp.cpp @@ -87,13 +87,13 @@ void Stamp::setPlayBack(bool playback) { _playback = playback; } -void Stamp::syncTo(Stamp &other) { - _need_sync = true; +void Stamp::syncTo(Stamp &other, int count) { + _need_sync += count; _sync_master = &other; } void Stamp::needSync() { - _need_sync = true; + ++_need_sync; } void Stamp::enableRollback(bool flag) { @@ -145,25 +145,43 @@ void Stamp::revise_l(int64_t dts, int64_t pts, int64_t &dts_out, int64_t &pts_ou // 音视频dts当前时间差 [AUTO-TRANSLATED:716468a6] // Audio and video dts current time difference int64_t dts_diff = _last_dts_in - _sync_master->_last_dts_in; - if (ABS(dts_diff) < 5000) { + if (ABS(dts_diff) < 5000 || _need_sync > 3) { + // 两种时间戳相差不得大于300ms + dts_diff = _relative_stamp - _sync_master->_relative_stamp; + // 强制同步音视频 + if (dts_diff > 300) { + dts_diff = 0; + } else if (dts_diff < -300) { + dts_diff = 0; + } // 如果绝对时间戳小于5秒,那么说明他们的起始时间戳是一致的,那么强制同步 [AUTO-TRANSLATED:5d11ef6a] // If the absolute timestamp is less than 5 seconds, then it means that their starting timestamps are consistent, then force synchronization auto target_stamp = _sync_master->_relative_stamp + dts_diff; if (target_stamp > _relative_stamp || _enable_rollback) { // 强制同步后,时间戳增加跳跃了,或允许回退 [AUTO-TRANSLATED:805424a9] // After forced synchronization, the timestamp increases jump, or allows rollback + if (_relative_stamp == target_stamp) { + return; + } TraceL << "Relative stamp changed: " << _relative_stamp << " -> " << target_stamp; _relative_stamp = target_stamp; } else { // 不允许回退, 则让另外一个Track的时间戳增长 [AUTO-TRANSLATED:428e8ce2] // Not allowed to rollback, then let the timestamp of the other Track increase target_stamp = _relative_stamp - dts_diff; + if (_sync_master->_relative_stamp == target_stamp) { + return; + } TraceL << "Relative stamp changed: " << _sync_master->_relative_stamp << " -> " << target_stamp; _sync_master->_relative_stamp = target_stamp; } } - _need_sync = false; - _sync_master->_need_sync = false; + if (_need_sync) { + --_need_sync; + } + if (_sync_master->_need_sync) { + --_sync_master->_need_sync; + } } } @@ -302,7 +320,7 @@ bool DtsGenerator::getDts_l(uint64_t pts, uint64_t &dts) { // Put pts into the sorting cache queue, the maximum cache queue is equal to the number of consecutive B frames _pts_sorter.emplace(pts); - if (_sorter_max_size && _pts_sorter.size() > _sorter_max_size) { + if (_sorter_max_size > 1 && _pts_sorter.size() > _sorter_max_size) { // 如果启用了pts排序(意味着存在B帧),并且pts排序缓存列队长度大于连续B帧个数, [AUTO-TRANSLATED:002c0d03] // If pts sorting is enabled (meaning there are B frames), and the length of the pts sorting cache queue is greater than the number of consecutive B frames, // 意味着后续的pts都会比最早的pts大,那么说明可以取出最早的pts了,这个pts将当做该帧的dts基准 [AUTO-TRANSLATED:86b8f679] diff --git a/src/Common/Stamp.h b/src/Common/Stamp.h index beeadcb9..3ca34a0e 100644 --- a/src/Common/Stamp.h +++ b/src/Common/Stamp.h @@ -117,7 +117,7 @@ public: * [AUTO-TRANSLATED:7ac41a76] */ - void syncTo(Stamp &other); + void syncTo(Stamp &other, int count = 1); /** * 是否允许时间戳回退 @@ -145,7 +145,7 @@ private: private: bool _playback = false; - bool _need_sync = false; + int _need_sync = 0; // 默认不允许时间戳回滚 [AUTO-TRANSLATED:0163ff03] // Default does not allow timestamp rollback bool _enable_rollback = false; diff --git a/src/Common/config.cpp b/src/Common/config.cpp index b83aaa55..eb9743df 100644 --- a/src/Common/config.cpp +++ b/src/Common/config.cpp @@ -30,7 +30,22 @@ bool loadIniConfig(const char *ini_path) { ini = exePath() + ".ini"; } try { - mINI::Instance().parseFile(ini); + mINI tmp; + tmp.parseFile(ini); + + auto &ref = mINI::Instance(); + for (auto &pr : tmp) { + if (ref.find(pr.first) == ref.end()) { + // 新增键 + WarnL << "unknow config: " << pr.first << " = " << pr.second; + ref.emplace(pr); + } else { + // 更新键 + ref[pr.first] = pr.second; + } + } + // 更新注释和排序 + ref.updateFrom(tmp); NOTICE_EMIT(BroadcastReloadConfigArgs, Broadcast::kBroadcastReloadConfig); return true; } catch (std::exception &) { @@ -66,6 +81,8 @@ const string kBroadcastRtcSctpClosed = "kBroadcastRtcSctpClosed"; const string kBroadcastRtcSctpSend = "kBroadcastRtcSctpSend"; const string kBroadcastRtcSctpReceived = "kBroadcastRtcSctpReceived"; const string kBroadcastPlayerCountChanged = "kBroadcastPlayerCountChanged"; +const string kBroadcastPlayerProxyFailed = "kBroadcastPlayerProxyFailed"; +const string kBroadcastCreateMuxer = "kBroadcastCreateMuxer"; } // namespace Broadcast @@ -265,7 +282,7 @@ static onceToken token([]() { mINI::Instance()[kHandshakeSecond] = 15; mINI::Instance()[kKeepAliveSecond] = 15; mINI::Instance()[kDirectProxy] = 1; - mINI::Instance()[kEnhanced] = 0; + mINI::Instance()[kEnhanced] = 1; }); } // namespace Rtmp @@ -375,6 +392,7 @@ const string kOpusPT = RTP_PROXY_FIELD "opus_pt"; const string kGopCache = RTP_PROXY_FIELD "gop_cache"; const string kRtpG711DurMs = RTP_PROXY_FIELD "rtp_g711_dur_ms"; const string kUdpRecvSocketBuffer = RTP_PROXY_FIELD "udp_recv_socket_buffer"; +const std::string kMergeFrame = RTP_PROXY_FIELD "merge_frame"; static onceToken token([]() { mINI::Instance()[kDumpDir] = ""; @@ -387,6 +405,7 @@ static onceToken token([]() { mINI::Instance()[kGopCache] = 1; mINI::Instance()[kRtpG711DurMs] = 100; mINI::Instance()[kUdpRecvSocketBuffer] = 4 * 1024 * 1024; + mINI::Instance()[kMergeFrame] = 1; }); } // namespace RtpProxy @@ -405,8 +424,10 @@ const string kWaitTrackReady = "wait_track_ready"; const string kPlayTrack = "play_track"; const string kProxyUrl = "proxy_url"; const string kRtspSpeed = "rtsp_speed"; +const string kSchema = "schema"; const string kLatency = "latency"; const string kPassPhrase = "passPhrase"; +const string kCustomHeader = "custom_header"; } // namespace Client } // namespace mediakit diff --git a/src/Common/config.h b/src/Common/config.h index a22eb67c..3becfcf1 100644 --- a/src/Common/config.h +++ b/src/Common/config.h @@ -55,31 +55,31 @@ extern const std::string kBroadcastRecordTs; // 收到http api请求广播 [AUTO-TRANSLATED:c72e7c3f] // Broadcast for receiving http api request extern const std::string kBroadcastHttpRequest; -#define BroadcastHttpRequestArgs const Parser &parser, const HttpSession::HttpResponseInvoker &invoker, bool &consumed, SockInfo &sender +#define BroadcastHttpRequestArgs const Parser &parser, const HttpSession::HttpResponseInvoker &invoker, bool &consumed, toolkit::SockInfo &sender // 在http文件服务器中,收到http访问文件或目录的广播,通过该事件控制访问http目录的权限 [AUTO-TRANSLATED:2de426b4] // In the http file server, broadcast for receiving http access to files or directories. Control access permissions to the http directory through this event. extern const std::string kBroadcastHttpAccess; -#define BroadcastHttpAccessArgs const Parser &parser, const std::string &path, const bool &is_dir, const HttpSession::HttpAccessPathInvoker &invoker, SockInfo &sender +#define BroadcastHttpAccessArgs const Parser &parser, const std::string &path, const std::string &file_path, const bool &is_dir, const HttpSession::HttpAccessPathInvoker &invoker, toolkit::SockInfo &sender // 在http文件服务器中,收到http访问文件或目录前的广播,通过该事件可以控制http url到文件路径的映射 [AUTO-TRANSLATED:0294d0c5] // In the http file server, broadcast before receiving http access to files or directories. Control the mapping from http url to file path through this event. // 在该事件中通过自行覆盖path参数,可以做到譬如根据虚拟主机或者app选择不同http根目录的目的 [AUTO-TRANSLATED:1bea3efb] // By overriding the path parameter in this event, you can achieve the purpose of selecting different http root directories based on virtual hosts or apps. extern const std::string kBroadcastHttpBeforeAccess; -#define BroadcastHttpBeforeAccessArgs const Parser &parser, std::string &path, SockInfo &sender +#define BroadcastHttpBeforeAccessArgs const Parser &parser, std::string &path, toolkit::SockInfo &sender // 该流是否需要认证?是的话调用invoker并传入realm,否则传入空的realm.如果该事件不监听则不认证 [AUTO-TRANSLATED:5f436d8f] // Does this stream need authentication? If yes, call invoker and pass in realm, otherwise pass in an empty realm. If this event is not listened to, no authentication will be performed. extern const std::string kBroadcastOnGetRtspRealm; -#define BroadcastOnGetRtspRealmArgs const MediaInfo &args, const RtspSession::onGetRealm &invoker, SockInfo &sender +#define BroadcastOnGetRtspRealmArgs const MediaInfo &args, const RtspSession::onGetRealm &invoker, toolkit::SockInfo &sender // 请求认证用户密码事件,user_name为用户名,must_no_encrypt如果为true,则必须提供明文密码(因为此时是base64认证方式),否则会导致认证失败 [AUTO-TRANSLATED:22b6dfcc] // Request authentication user password event, user_name is the username, must_no_encrypt if true, then the plaintext password must be provided (because it is base64 authentication method at this time), otherwise it will lead to authentication failure. // 获取到密码后请调用invoker并输入对应类型的密码和密码类型,invoker执行时会匹配密码 [AUTO-TRANSLATED:8c57fd43] // After getting the password, please call invoker and input the corresponding type of password and password type. The invoker will match the password when executing. extern const std::string kBroadcastOnRtspAuth; -#define BroadcastOnRtspAuthArgs const MediaInfo &args, const std::string &realm, const std::string &user_name, const bool &must_no_encrypt, const RtspSession::onAuth &invoker, SockInfo &sender +#define BroadcastOnRtspAuthArgs const MediaInfo &args, const std::string &realm, const std::string &user_name, const bool &must_no_encrypt, const RtspSession::onAuth &invoker, toolkit::SockInfo &sender // 推流鉴权结果回调对象 [AUTO-TRANSLATED:7e508ed1] // Push stream authentication result callback object @@ -90,7 +90,7 @@ using PublishAuthInvoker = std::function; // 播放rtsp/rtmp/http-flv事件广播,通过该事件控制播放鉴权 [AUTO-TRANSLATED:eddd7014] // Broadcast for playing rtsp/rtmp/http-flv events. Control playback authentication through this event. extern const std::string kBroadcastMediaPlayed; -#define BroadcastMediaPlayedArgs const MediaInfo &args, const Broadcast::AuthInvoker &invoker, SockInfo &sender +#define BroadcastMediaPlayedArgs const MediaInfo &args, const Broadcast::AuthInvoker &invoker, toolkit::SockInfo &sender // shell登录鉴权 [AUTO-TRANSLATED:26b135d4] // Shell login authentication extern const std::string kBroadcastShellLogin; -#define BroadcastShellLoginArgs const std::string &user_name, const std::string &passwd, const Broadcast::AuthInvoker &invoker, SockInfo &sender +#define BroadcastShellLoginArgs const std::string &user_name, const std::string &passwd, const Broadcast::AuthInvoker &invoker, toolkit::SockInfo &sender // 停止rtsp/rtmp/http-flv会话后流量汇报事件广播 [AUTO-TRANSLATED:69df61d8] // Broadcast for traffic reporting event after stopping rtsp/rtmp/http-flv session extern const std::string kBroadcastFlowReport; -#define BroadcastFlowReportArgs const MediaInfo &args, const uint64_t &totalBytes, const uint64_t &totalDuration, const bool &isPlayer, SockInfo &sender +#define BroadcastFlowReportArgs const MediaInfo &args, const uint64_t &totalBytes, const uint64_t &totalDuration, const bool &isPlayer, toolkit::SockInfo &sender // 未找到流后会广播该事件,请在监听该事件后去拉流或其他方式产生流,这样就能按需拉流了 [AUTO-TRANSLATED:0c00171d] // This event will be broadcast after the stream is not found. Please pull the stream or other methods to generate the stream after listening to this event, so that you can pull the stream on demand. extern const std::string kBroadcastNotFoundStream; -#define BroadcastNotFoundStreamArgs const MediaInfo &args, SockInfo &sender, const std::function &closePlayer +#define BroadcastNotFoundStreamArgs const MediaInfo &args, toolkit::SockInfo &sender, const std::function &closePlayer // 某个流无人消费时触发,目的为了实现无人观看时主动断开拉流等业务逻辑 [AUTO-TRANSLATED:3c45f002] // Triggered when a stream is not consumed by anyone. The purpose is to achieve business logic such as actively disconnecting the pull stream when no one is watching. @@ -126,7 +126,7 @@ extern const std::string kBroadcastStreamNoneReader; // rtp推流被动停止时触发 [AUTO-TRANSLATED:43881965] // Triggered when rtp push stream is passively stopped. extern const std::string kBroadcastSendRtpStopped; -#define BroadcastSendRtpStoppedArgs MultiMediaSourceMuxer &sender, const std::string &ssrc, const SockException &ex +#define BroadcastSendRtpStoppedArgs MultiMediaSourceMuxer &sender, const std::string &ssrc, const toolkit::SockException &ex // 更新配置文件事件广播,执行loadIniConfig函数加载配置文件成功后会触发该广播 [AUTO-TRANSLATED:ad4e167d] // Update configuration file event broadcast. This broadcast will be triggered after the loadIniConfig function loads the configuration file successfully. @@ -161,6 +161,12 @@ extern const std::string kBroadcastRtcSctpReceived; extern const std::string kBroadcastPlayerCountChanged; #define BroadcastPlayerCountChangedArgs const MediaTuple& args, const int& count +extern const std::string kBroadcastPlayerProxyFailed; +#define BroadcastPlayerProxyFailedArgs const PlayerProxy& sender, const toolkit::SockException &ex + +extern const std::string kBroadcastCreateMuxer; +#define BroadcastCreateMuxerArgs MediaSinkInterface::Ptr &delegate, const MultiMediaSourceMuxer &sender + #define ReloadConfigTag ((void *)(0xFF)) #define RELOAD_KEY(arg, key) \ do { \ @@ -560,8 +566,8 @@ extern const std::string kPSPT; // rtp server opus 的pt [AUTO-TRANSLATED:9f91f85a] // Rtp server opus pt extern const std::string kOpusPT; -// RtpSender相关功能是否提前开启gop缓存优化级联秒开体验,默认开启 [AUTO-TRANSLATED:40c37c77] -// Whether to enable gop cache optimization cascade second-open experience for RtpSender related functions, enabled by default +// startSendRtp、startRecord相关功能是否提前开启gop缓存优化级联秒开体验,默认开启, 并缓存1个GOP [AUTO-TRANSLATED:40c37c77] +// Whether to enable gop cache optimization cascade second-open experience for startSendRtp/startRecord related functions, enabled by default, and cached 1 GOP extern const std::string kGopCache; // 国标发送g711 rtp 打包时,每个包的语音时长是多少,默认是100 ms,范围为20~180ms (gb28181-2016,c.2.4规定), [AUTO-TRANSLATED:3b3916a3] // When sending g711 rtp packets in national standard, what is the duration of each packet, the default is 100 ms, the range is 20~180ms (gb28181-2016, c.2.4), @@ -570,6 +576,8 @@ extern const std::string kGopCache; extern const std::string kRtpG711DurMs; // udp recv socket buffer size extern const std::string kUdpRecvSocketBuffer; +// ps/ts解析后是否等待下一帧以判断本帧是否完整,开启后提高兼容性,但是可能增加延时 +extern const std::string kMergeFrame; } // namespace RtpProxy /** @@ -635,6 +643,10 @@ extern const std::string kRtspSpeed; extern const std::string kLatency; // Set SRT PassPhrase extern const std::string kPassPhrase; +// 自定义rtsp/http头 +extern const std::string kCustomHeader; +// 指定采用什么播放协议 +extern const std::string kSchema; } // namespace Client } // namespace mediakit diff --git a/src/Common/macros.cpp b/src/Common/macros.cpp index e9574dcc..51a5eda2 100644 --- a/src/Common/macros.cpp +++ b/src/Common/macros.cpp @@ -30,7 +30,7 @@ namespace mediakit { * [AUTO-TRANSLATED:f214f734] */ #if !defined(ENABLE_VERSION) -const char kServerName[] = "ZLMediaKit-8.0(build in " __DATE__ " " __TIME__ ")"; +const char kServerName[] = "ZLMediaKit-9.0(build in " __DATE__ " " __TIME__ ")"; #else const char kServerName[] = "ZLMediaKit(git hash:" COMMIT_HASH "/" COMMIT_TIME ",branch:" BRANCH_NAME ",build time:" BUILD_TIME ")"; #endif diff --git a/src/Common/macros.h b/src/Common/macros.h index d6d6d683..4bca5b7e 100644 --- a/src/Common/macros.h +++ b/src/Common/macros.h @@ -61,6 +61,7 @@ item = 0; \ } #endif // CLEAR_ARR + #define RTC_SCHEMA "rtc" #define RTSP_SCHEMA "rtsp" #define RTMP_SCHEMA "rtmp" diff --git a/src/Common/strCoding.cpp b/src/Common/strCoding.cpp index 5eafe40f..6a6b3ae4 100644 --- a/src/Common/strCoding.cpp +++ b/src/Common/strCoding.cpp @@ -38,14 +38,14 @@ void UnicodeToUTF8(char *pOut, const wchar_t *pText) { return; } -char HexCharToBin(char ch) { +signed char HexCharToBin(char ch) { if (ch >= '0' && ch <= '9') return (char)(ch - '0'); if (ch >= 'a' && ch <= 'f') return (char)(ch - 'a' + 10); if (ch >= 'A' && ch <= 'F') return (char)(ch - 'A' + 10); return -1; } -char HexStrToBin(const char *str) { +signed char HexStrToBin(const char *str) { auto high = HexCharToBin(str[0]); auto low = HexCharToBin(str[1]); if (high == -1 || low == -1) { @@ -81,7 +81,7 @@ static string UrlDecodeCommon(const string &str,const char* dont_unescape){ output.append(str, i, len - i); break; } - char ch = HexStrToBin(&(str[i + 1])); + signed char ch = HexStrToBin(&(str[i + 1])); if (ch == -1 || strchr(dont_unescape, (unsigned char)ch) != NULL) { // %后面两个字节不是16进制字符串,转义失败;或者转义出来可能会造成url包含非path部分,比如#?,说明提交的是非法拼接的url;直接拼接3个原始字符 [AUTO-TRANSLATED:7c734054] // The two bytes after % are not hexadecimal strings, the escape fails; or the escaped result may cause the url to contain non-path parts, such as #?, indicating that the submitted url is illegally spliced; directly splice the three original characters @@ -142,7 +142,7 @@ std::string strCoding::UrlDecodeComponent(const std::string &str) { output.append(str, i, len - i); break; } - char ch = HexStrToBin(&(str[i + 1])); + signed char ch = HexStrToBin(&(str[i + 1])); if (ch == -1) { // %后面两个字节不是16进制字符串,转义失败;直接拼接3个原始字符 [AUTO-TRANSLATED:10e614a4] // The two bytes after % are not hexadecimal strings, the escape fails; directly splice the three original characters diff --git a/src/Extension/CommonRtp.cpp b/src/Extension/CommonRtp.cpp index fd3b8abd..b4687d07 100644 --- a/src/Extension/CommonRtp.cpp +++ b/src/Extension/CommonRtp.cpp @@ -28,6 +28,8 @@ bool CommonRtpDecoder::inputRtp(const RtpPacket::Ptr &rtp, bool){ if (payload_size <= 0) { // 无实际负载 [AUTO-TRANSLATED:305af48f] // No actual load + // 无实际负载也需要记录序号,否则会误判丢包 + _last_seq = rtp->getSeq(); return false; } auto payload = rtp->getPayload(); @@ -93,4 +95,4 @@ bool CommonRtpEncoder::inputFrame(const Frame::Ptr &frame){ is_key = false; } return len > 0; -} \ No newline at end of file +} diff --git a/src/Extension/Factory.cpp b/src/Extension/Factory.cpp index d09d18a1..80affb79 100644 --- a/src/Extension/Factory.cpp +++ b/src/Extension/Factory.cpp @@ -21,8 +21,11 @@ namespace mediakit { static std::unordered_map s_plugins; +REGISTER_CODEC(vp8_plugin); +REGISTER_CODEC(vp9_plugin); REGISTER_CODEC(h264_plugin); REGISTER_CODEC(h265_plugin); +REGISTER_CODEC(av1_plugin); REGISTER_CODEC(jpeg_plugin); REGISTER_CODEC(aac_plugin); REGISTER_CODEC(opus_plugin); @@ -30,6 +33,8 @@ REGISTER_CODEC(g711a_plugin) REGISTER_CODEC(g711u_plugin); REGISTER_CODEC(l16_plugin); REGISTER_CODEC(mp3_plugin); +REGISTER_CODEC(mp2v_plugin); +REGISTER_CODEC(mp2a_plugin); void Factory::registerPlugin(const CodecPlugin &plugin) { InfoL << "Load codec: " << getCodecName(plugin.getCodec()); @@ -96,10 +101,15 @@ static CodecId getVideoCodecIdByAmf(const AMFValue &val) { if (val.type() != AMF_NULL) { auto type_id = (RtmpVideoCodec)val.as_integer(); switch (type_id) { + case RtmpVideoCodec::fourcc_avc1: case RtmpVideoCodec::h264: return CodecH264; case RtmpVideoCodec::fourcc_hevc: case RtmpVideoCodec::h265: return CodecH265; + case RtmpVideoCodec::av1: case RtmpVideoCodec::fourcc_av1: return CodecAV1; + case RtmpVideoCodec::vp8: + case RtmpVideoCodec::fourcc_vp8: return CodecVP8; + case RtmpVideoCodec::vp9: case RtmpVideoCodec::fourcc_vp9: return CodecVP9; default: WarnL << "Unsupported codec: " << (int)type_id; return CodecInvalid; } @@ -152,7 +162,8 @@ static CodecId getAudioCodecIdByAmf(const AMFValue &val) { case RtmpAudioCodec::adpcm: return CodecADPCM; case RtmpAudioCodec::g711a: return CodecG711A; case RtmpAudioCodec::g711u: return CodecG711U; - case RtmpAudioCodec::opus: return CodecOpus; + case RtmpAudioCodec::opus: + case RtmpAudioCodec::fourcc_opus: return CodecOpus; default: WarnL << "Unsupported codec: " << (int)type_id; return CodecInvalid; } } @@ -190,15 +201,16 @@ AMFValue Factory::getAmfByCodecId(CodecId codecId) { GET_CONFIG(bool, enhanced, Rtmp::kEnhanced); switch (codecId) { case CodecAAC: return AMFValue((int)RtmpAudioCodec::aac); - case CodecH264: return AMFValue((int)RtmpVideoCodec::h264); + case CodecH264: return enhanced ? AMFValue((int)RtmpVideoCodec::fourcc_avc1) : AMFValue((int)RtmpVideoCodec::h264); case CodecH265: return enhanced ? AMFValue((int)RtmpVideoCodec::fourcc_hevc) : AMFValue((int)RtmpVideoCodec::h265); case CodecG711A: return AMFValue((int)RtmpAudioCodec::g711a); case CodecG711U: return AMFValue((int)RtmpAudioCodec::g711u); - case CodecOpus: return AMFValue((int)RtmpAudioCodec::opus); + case CodecOpus: return enhanced ? AMFValue((int)RtmpAudioCodec::fourcc_opus) : AMFValue((int)RtmpAudioCodec::opus); case CodecADPCM: return AMFValue((int)RtmpAudioCodec::adpcm); case CodecMP3: return AMFValue((int)RtmpAudioCodec::mp3); - case CodecAV1: return AMFValue((int)RtmpVideoCodec::fourcc_av1); - case CodecVP9: return AMFValue((int)RtmpVideoCodec::fourcc_vp9); + case CodecAV1: return enhanced ? AMFValue((int)RtmpVideoCodec::fourcc_av1) : AMFValue((int)RtmpVideoCodec::av1); + case CodecVP8: return enhanced ? AMFValue((int)RtmpVideoCodec::fourcc_vp8) : AMFValue((int)RtmpVideoCodec::vp8); + case CodecVP9: return enhanced ? AMFValue((int)RtmpVideoCodec::fourcc_vp9) : AMFValue((int)RtmpVideoCodec::vp9); default: return AMFValue(AMF_NULL); } } diff --git a/src/Extension/Frame.h b/src/Extension/Frame.h index ca38b5a7..d78cbe18 100644 --- a/src/Extension/Frame.h +++ b/src/Extension/Frame.h @@ -54,7 +54,9 @@ typedef enum { XX(CodecG722, TrackAudio, 18, "G722", PSI_STREAM_AUDIO_G722, MOV_OBJECT_NONE) \ XX(CodecG723, TrackAudio, 19, "G723", PSI_STREAM_AUDIO_G723, MOV_OBJECT_NONE) \ XX(CodecG728, TrackAudio, 20, "G728", PSI_STREAM_RESERVED, MOV_OBJECT_NONE) \ - XX(CodecG729, TrackAudio, 21, "G729", PSI_STREAM_AUDIO_G729, MOV_OBJECT_NONE) + XX(CodecG729, TrackAudio, 21, "G729", PSI_STREAM_AUDIO_G729, MOV_OBJECT_NONE) \ + XX(CodecMP2V, TrackVideo, 22, "MPV", PSI_STREAM_MPEG2, MOV_OBJECT_MP2V) \ + XX(CodecMP2A, TrackAudio, 23, "MPA", PSI_STREAM_AUDIO_MPEG1, MOV_OBJECT_MP3) typedef enum { CodecInvalid = -1, @@ -733,9 +735,9 @@ public: * [AUTO-TRANSLATED:a3e7e6db] */ bool inputFrame(const Frame::Ptr &frame) override { - std::lock_guard lck(_mtx); doStatistics(frame); bool ret = false; + std::lock_guard lck(_mtx); for (auto &pr : _delegates) { if (pr.second->inputFrame(frame)) { ret = true; @@ -767,7 +769,6 @@ public: * [AUTO-TRANSLATED:73cb2ab0] */ uint64_t getVideoKeyFrames() const { - std::lock_guard lck(_mtx); return _video_key_frames; } @@ -778,22 +779,18 @@ public: * [AUTO-TRANSLATED:118b395e] */ uint64_t getFrames() const { - std::lock_guard lck(_mtx); return _frames; } size_t getVideoGopSize() const { - std::lock_guard lck(_mtx); return _gop_size; } size_t getVideoGopInterval() const { - std::lock_guard lck(_mtx); return _gop_interval_ms; } int64_t getDuration() const { - std::lock_guard lck(_mtx); return _stamp.getRelativeStamp(); } diff --git a/src/Extension/Track.cpp b/src/Extension/Track.cpp index 4e9e424c..6048632f 100644 --- a/src/Extension/Track.cpp +++ b/src/Extension/Track.cpp @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. * * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). diff --git a/src/Extension/Track.h b/src/Extension/Track.h index 212bd7fc..a9c3531b 100644 --- a/src/Extension/Track.h +++ b/src/Extension/Track.h @@ -195,16 +195,21 @@ public: _fps = fps; } + VideoTrackImp(CodecId codec_id) { + _codec_id = codec_id; + _fps = 30; + } + int getVideoWidth() const override { return _width; } int getVideoHeight() const override { return _height; } float getVideoFps() const override { return _fps; } - bool ready() const override { return true; } + bool ready() const override { return _width > 0 && _height > 0; } Track::Ptr clone() const override { return std::make_shared(*this); } Sdp::Ptr getSdp(uint8_t payload_type) const override; CodecId getCodecId() const override { return _codec_id; } -private: +protected: CodecId _codec_id; int _width = 0; int _height = 0; @@ -324,7 +329,7 @@ public: Track::Ptr clone() const override { return std::make_shared(*this); } Sdp::Ptr getSdp(uint8_t payload_type) const override; -private: +protected: CodecId _codecid; int _sample_rate; int _channels; diff --git a/src/Http/HlsPlayer.cpp b/src/Http/HlsPlayer.cpp index 2443297e..4c5f8eaf 100644 --- a/src/Http/HlsPlayer.cpp +++ b/src/Http/HlsPlayer.cpp @@ -22,6 +22,8 @@ HlsPlayer::HlsPlayer(const EventPoller::Ptr &poller) { void HlsPlayer::play(const string &url) { _play_result = false; _play_url = url; + _last_sequence = -1; + _playlist_reload_changed = true; setProxyUrl((*this)[Client::kProxyUrl]); setAllowResendRequest(true); fetchIndexFile(); @@ -36,6 +38,7 @@ void HlsPlayer::fetchIndexFile() { } setCompleteTimeout((*this)[Client::kTimeoutMS].as()); setMethod("GET"); + addCustomHeader(this); sendRequest(_play_url); } @@ -133,6 +136,10 @@ void HlsPlayer::fetchSegment() { if (!(*this)[Client::kNetAdapter].empty()) { _http_ts_player->setNetAdapter((*this)[Client::kNetAdapter]); } + } else { + // 每次请求新的ts片段时重置HttpTSPlayer状态 + _http_ts_player->clear(); + _http_ts_player->setProxyUrl((*this)[Client::kProxyUrl]); } Ticker ticker; @@ -197,10 +204,12 @@ void HlsPlayer::fetchSegment() { bool HlsPlayer::onParsed(bool is_m3u8_inner, int64_t sequence, const map &ts_map) { if (!is_m3u8_inner) { + auto playlist_changed = _last_sequence != sequence; + _playlist_reload_changed = playlist_changed; // 这是ts播放列表 [AUTO-TRANSLATED:7ce3d81b] // This is the ts playlist // This is the ts playlist - if (_last_sequence == sequence) { + if (!playlist_changed) { // 如果是重复的ts列表,那么忽略 [AUTO-TRANSLATED:d15a47f3] // If it is a duplicate ts list, then ignore it // 但是需要注意, 如果当前ts列表为空了, 那么表明直播结束了或者m3u8文件有问题,需要重新拉流 [AUTO-TRANSLATED:438a8df0] @@ -278,6 +287,7 @@ void HlsPlayer::onResponseHeader(const string &status, const HttpClient::HttpHea void HlsPlayer::onResponseBody(const char *buf, size_t size) { _m3u8.append(buf, size); + _recvtotalbytes += getRecvTotalBytes(); } void HlsPlayer::onResponseCompleted(const SockException &ex) { @@ -306,27 +316,23 @@ float HlsPlayer::delaySecond() { if (HlsParser::isM3u8() && HlsParser::getTargetDur() > 0) { float targetOffset; if (HlsParser::isLive()) { - // see RFC 8216, Section 4.4.3.8. - // 根据rfc刷新index列表的周期应该是分段时间x3, 因为根据规范播放器只处理最后3个Segment [AUTO-TRANSLATED:07168708] - // According to the rfc, the refresh cycle of the index list should be 3 times the segment time, because according to the specification, the player only processes the last 3 Segments - // refresh the index list according to rfc cycle should be the segment time x3, - // because according to the specification, the player only handles the last 3 segments - targetOffset = (float)(3 * HlsParser::getTargetDur()); + // RFC 8216 Section 6.3.4: + // after a changed playlist reload, wait at least one target + // duration; after an unchanged reload, wait half a target duration. + return _playlist_reload_changed ? (float) HlsParser::getTargetDur() + : (float) HlsParser::getTargetDur() / 2.0f; } else { // 点播则一般m3u8文件不会在改变了, 没必要频繁的刷新, 所以按照总时间来进行刷新 [AUTO-TRANSLATED:2ac0a29e] // On-demand generally does not change the m3u8 file, there is no need to refresh frequently, so refresh according to the total time - // On-demand, the m3u8 file will generally not change, so there is no need to refresh frequently, targetOffset = HlsParser::getTotalDuration(); } // 取最小值, 避免因为分段时长不规则而导致的问题 [AUTO-TRANSLATED:073dff48] // Take the minimum value to avoid problems caused by irregular segment durations - // Take the minimum value to avoid problems caused by irregular segment duration if (targetOffset > HlsParser::getTotalDuration()) { targetOffset = HlsParser::getTotalDuration(); } // 根据规范为一半的时间 [AUTO-TRANSLATED:07652637] // According to the specification, it is half the time - // According to the specification, it is half the time if (targetOffset / 2 > 1.0f) { return targetOffset / 2; } @@ -353,6 +359,13 @@ void HlsPlayer::playDelay(float delay_sec) { }, getPoller())); } +size_t HlsPlayer::getRecvSpeed() { + return TcpClient::getRecvSpeed() + (_http_ts_player ? _http_ts_player->getRecvSpeed() : 0); +} + +size_t HlsPlayer::getRecvTotalBytes() { + return TcpClient::getRecvTotalBytes() + (_http_ts_player ? _http_ts_player->getRecvTotalBytes() : 0); +} ////////////////////////////////////////////////////////////////////////// void HlsDemuxer::start(const EventPoller::Ptr &poller, TrackListener *listener) { @@ -476,6 +489,7 @@ void HlsPlayerImp::onPacket(const char *data, size_t len) { if (_decoder && _demuxer) { _decoder->input((uint8_t *) data, len); } + _recvtotalbytes += HlsPlayer::getRecvTotalBytes(); } void HlsPlayerImp::addTrackCompleted() { @@ -527,4 +541,11 @@ vector HlsPlayerImp::getTracks(bool ready) const { return static_pointer_cast(_demuxer)->getTracks(ready); } +size_t HlsPlayerImp::getRecvSpeed() { + return PlayerImp::getRecvSpeed(); +} + +size_t HlsPlayerImp::getRecvTotalBytes() { + return _recvtotalbytes; +} }//namespace mediakit diff --git a/src/Http/HlsPlayer.h b/src/Http/HlsPlayer.h index 10373894..3f3d965a 100644 --- a/src/Http/HlsPlayer.h +++ b/src/Http/HlsPlayer.h @@ -49,7 +49,7 @@ private: std::deque > > _frame_cache; }; -class HlsPlayer : public HttpClientImp , public PlayerBase , public HlsParser{ +class HlsPlayer: public HttpClientImp, public PlayerBase, public HlsParser { public: HlsPlayer(const toolkit::EventPoller::Ptr &poller); @@ -73,6 +73,9 @@ public: */ void teardown() override; + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; + protected: /** * 收到ts包 @@ -127,12 +130,21 @@ private: int _timeout_multiple = MIN_TIMEOUT_MULTIPLE; int _try_fetch_index_times = 0; int _ts_download_failed_count = 0; + // RFC 8216 reload interval depends on whether the last media sequence + // changed. We intentionally keep this lightweight and only track the + // sequence number for live playlist refresh timing. + bool _playlist_reload_changed = true; + +protected: + size_t _recvtotalbytes = 0; }; -class HlsPlayerImp : public PlayerImp, private TrackListener { +class HlsPlayerImp final: public PlayerImp, private TrackListener { public: using Ptr = std::shared_ptr; HlsPlayerImp(const toolkit::EventPoller::Ptr &poller = nullptr); + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; private: //// HlsPlayer override//// diff --git a/src/Http/HttpBody.cpp b/src/Http/HttpBody.cpp index b73538c7..f7b0d3d0 100644 --- a/src/Http/HttpBody.cpp +++ b/src/Http/HttpBody.cpp @@ -163,7 +163,7 @@ static std::shared_ptr getSharedMmap(const string &file_path, int64_t &fil if (addr_ == nullptr) { mmap_close(hfile, hmapping, addr_); - WarnL << "MapViewOfFile() " << file_path << " failed:"; + WarnL << "MapViewOfFile() " << file_path << " failed:"; return nullptr; } @@ -194,6 +194,13 @@ static std::shared_ptr getSharedMmap(const string &file_path, int64_t &fil } HttpFileBody::HttpFileBody(const string &file_path, bool use_mmap) { + + // 判断是否为目录,避免对目录进行mmap操作,导致程序崩溃。 + if (File::is_dir(file_path)) { + _read_to = -1; + return; + } + if (use_mmap ) { _map_addr = getSharedMmap(file_path, _read_to); } @@ -289,7 +296,7 @@ Buffer::Ptr HttpFileBody::readData(size_t size) { // Data is read ret->setSize(iRead); _file_offset += iRead; - return std::move(ret); + return ret; } // 读取文件异常,文件真实长度小于声明长度 [AUTO-TRANSLATED:89d09f9b] // File reading exception, the actual length of the file is less than the declared length @@ -378,7 +385,7 @@ string HttpMultiFormBody::multiFormBodyPrefix(const HttpArgs &args, const string << "file" << "\"; filename=\"" << fileName << "\"\r\n"; body << "Content-Type: application/octet-stream\r\n\r\n"; - return std::move(body); + return body; } HttpBufferBody::HttpBufferBody(Buffer::Ptr buffer) { diff --git a/src/Http/HttpClient.cpp b/src/Http/HttpClient.cpp index 23b97b42..4ee3f375 100644 --- a/src/Http/HttpClient.cpp +++ b/src/Http/HttpClient.cpp @@ -18,6 +18,15 @@ using namespace toolkit; namespace mediakit { +static bool connectionContainsClose(const string &connection) { + for (auto token : split(connection, ",")) { + if (strToLower(trim(std::move(token))) == "close") { + return true; + } + } + return false; +} + void HttpClient::sendRequest(const string &url) { clearResponse(); _url = url; @@ -56,26 +65,24 @@ void HttpClient::sendRequest(const string &url) { } auto host_header = host; splitUrl(host, host, port); + auto keep_alive = _request_keep_alive; + auto persistent = _http_persistent && keep_alive; + bool protocol_changed = (_is_https != is_https); + bool host_changed = (_last_host != host + ":" + to_string(port)) || protocol_changed; + _last_host = host + ":" + to_string(port); + _is_https = is_https; + _header.emplace("Host", host_header); _header.emplace("User-Agent", kServerName); _header.emplace("Accept", "*/*"); _header.emplace("Accept-Language", "zh-CN,zh;q=0.8"); - if (_http_persistent) { - _header.emplace("Connection", "keep-alive"); - } else { - _header.emplace("Connection", "close"); - } - _http_persistent = true; + _header.emplace("Connection", keep_alive ? "keep-alive" : "close"); if (_body && _body->remainSize()) { _header.emplace("Content-Length", to_string(_body->remainSize())); GET_CONFIG(string, charSet, Http::kCharSet); _header.emplace("Content-Type", "application/x-www-form-urlencoded; charset=" + charSet); } - bool host_changed = (_last_host != host + ":" + to_string(port)) || (_is_https != is_https); - _last_host = host + ":" + to_string(port); - _is_https = is_https; - auto cookies = HttpCookieStorage::Instance().get(_last_host, _path); _StrPrinter printer; for (auto &cookie : cookies) { @@ -85,13 +92,24 @@ void HttpClient::sendRequest(const string &url) { printer.pop_back(); _header.emplace("Cookie", printer); } - if (!alive() || host_changed || !_http_persistent) { - if (isUsedProxy()) { + if (isUsedProxy()) { + // All proxy traffic uses CONNECT, so reuse is limited to the same tunnel target. + bool proxy_reuse = alive() && persistent && !host_changed && _proxy_connected; + + if (!proxy_reuse) { + _http_persistent = keep_alive; _proxy_connected = false; - startConnect(_proxy_host, _proxy_port, _wait_header_ms / 1000.0f); + startConnectWithProxy(host, _proxy_host, _proxy_port, _wait_header_ms / 1000.0f); } else { - startConnect(host, port, _wait_header_ms / 1000.0f); + SockException ex; + onConnect_l(ex); } + return; + } + + if (!alive() || host_changed || !persistent) { + _http_persistent = keep_alive; + startConnect(host, port, _wait_header_ms / 1000.0f); } else { SockException ex; onConnect_l(ex); @@ -103,6 +121,9 @@ void HttpClient::clear() { _user_set_header.clear(); _body.reset(); _method.clear(); + _request_keep_alive = true; + // Keep transport-level state so a live direct/proxy connection can still + // be reused after the caller resets only the per-request state. clearResponse(); } @@ -182,6 +203,8 @@ void HttpClient::onConnect_l(const SockException &ex) { _path.clear(); } else { printer << "CONNECT " << _last_host << " HTTP/1.1\r\n"; + printer << "Host: " << _last_host << "\r\n"; + printer << "User-Agent: " << kServerName << "\r\n"; printer << "Proxy-Connection: keep-alive\r\n"; if (!_proxy_auth.empty()) { printer << "Proxy-Authorization: Basic " << _proxy_auth << "\r\n"; @@ -198,9 +221,8 @@ void HttpClient::onRecv(const Buffer::Ptr &pBuf) { void HttpClient::onError(const SockException &ex) { if (ex.getErrCode() == Err_reset && _allow_resend_request && _http_persistent && _recved_body_size == 0 && !_header_recved) { - // 连接被重置,可能是服务器主动断开了连接, 或者服务器内核参数或防火墙的持久连接空闲时间超时或不一致. [AUTO-TRANSLATED:8a78f452] - // The connection was reset, possibly because the server actively closed the connection, or the server kernel parameters or firewall's persistent connection idle timeout or inconsistency. - // 如果是持久化连接,那么我们可以通过重连来解决这个问题 [AUTO-TRANSLATED:6c113e17] + // 连接被重置,可能是服务器主动断开了连接, 或者服务器内核参数或防火墙的持久连接空闲时间超时或不一致. + // 如果是持久化连接,那么我们可以通过重连来解决这个问题 // If it is a persistent connection, we can solve this problem by reconnecting // The connection was reset, possibly because the server actively disconnected the connection, // or the persistent connection idle time of the server kernel parameters or firewall timed out or inconsistent. @@ -215,12 +237,17 @@ void HttpClient::onError(const SockException &ex) { ssize_t HttpClient::onRecvHeader(const char *data, size_t len) { _parser.parse(data, len); - if (_parser.status() == "302" || _parser.status() == "301" || _parser.status() == "303") { + auto connection_close = connectionContainsClose(_parser["Connection"]); + if (connection_close) { + _http_persistent = false; + } + if (_parser.status() == "302" || _parser.status() == "301" || _parser.status() == "303" || _parser.status() == "307") { auto new_url = Parser::mergeUrl(_url, _parser["Location"]); if (new_url.empty()) { throw invalid_argument("未找到Location字段(跳转url)"); } - if (onRedirectUrl(new_url, _parser.status() == "302")) { + bool temporary_redirect = _parser.status() == "302" || _parser.status() == "307"; + if (onRedirectUrl(new_url, temporary_redirect)) { HttpClient::sendRequest(new_url); return 0; } @@ -260,7 +287,7 @@ ssize_t HttpClient::onRecvHeader(const char *data, size_t len) { _total_body_size = -1; } - if (_total_body_size == 0) { + if (_total_body_size == 0 || _method == "HEAD") { // 后续没content,本次http请求结束 [AUTO-TRANSLATED:8532172f] // There is no content afterwards, this http request ends onResponseCompleted_l(SockException(Err_success, "The request is successful but has no body")); @@ -463,6 +490,13 @@ void HttpClient::setCompleteTimeout(size_t timeout_ms) { _wait_complete_ms = timeout_ms; } +void HttpClient::setRequestKeepAlive(bool enable) { + _request_keep_alive = enable; + if (!enable) { + _http_persistent = false; + } +} + bool HttpClient::isUsedProxy() const { return _used_proxy; } @@ -472,19 +506,50 @@ bool HttpClient::isProxyConnected() const { } void HttpClient::setProxyUrl(string proxy_url) { + auto old_used_proxy = _used_proxy; + auto old_proxy_host = _proxy_host; + auto old_proxy_port = _proxy_port; + auto old_proxy_auth = _proxy_auth; + _proxy_url = std::move(proxy_url); if (!_proxy_url.empty()) { + _proxy_host.clear(); + _proxy_port = 0; + _proxy_auth.clear(); parseProxyUrl(_proxy_url, _proxy_host, _proxy_port, _proxy_auth); _used_proxy = true; } else { _used_proxy = false; + _proxy_host.clear(); + _proxy_port = 0; + _proxy_auth.clear(); + _proxy_connected = false; + } + + auto proxy_config_changed = old_used_proxy != _used_proxy + || old_proxy_host != _proxy_host + || old_proxy_port != _proxy_port + || old_proxy_auth != _proxy_auth; + if (proxy_config_changed) { + // A proxy mode or endpoint change must not reuse the previous transport. + _http_persistent = false; + _proxy_connected = false; } } bool HttpClient::checkProxyConnected(const char *data, size_t len) { - auto ret = strstr(data, "HTTP/1.1 200 Connection established"); - _proxy_connected = ret != nullptr; - return _proxy_connected; + string response(data, len); + if (response.find("HTTP/1.1 200") != string::npos || response.find("HTTP/1.0 200") != string::npos) { + _proxy_connected = true; + return true; + } + + _proxy_connected = false; + // CONNECT failed, which usually means the proxy rejected the tunnel request, + // does not support CONNECT for this target, or the proxy authentication is invalid. + WarnL << "proxy CONNECT failed, status line: " + << response.substr(0, response.find("\r\n")); + return false; } void HttpClient::setAllowResendRequest(bool allow) { diff --git a/src/Http/HttpClient.h b/src/Http/HttpClient.h index c7ca2433..caaa7a03 100644 --- a/src/Http/HttpClient.h +++ b/src/Http/HttpClient.h @@ -60,8 +60,8 @@ public: virtual void sendRequest(const std::string &url); /** - * 重置对象 - * Reset object + * 重置当前请求相关状态,并尽量保留可复用的传输状态 + * Reset per-request state while preserving reusable transport state when possible * [AUTO-TRANSLATED:d23b5bbb] */ @@ -168,10 +168,10 @@ public: void setHeaderTimeout(size_t timeout_ms); /** - * 设置接收body数据超时时间, 默认5秒 + * 设置接收body数据超时时间, 默认10秒 * 此参数可以用于处理超大body回复的超时问题 * 此参数可以等于0 - * Set the timeout for receiving body data, default 5 seconds + * Set the timeout for receiving body data, default 10 seconds * This parameter can be used to handle timeout issues for large body responses * This parameter can be equal to 0 @@ -189,6 +189,14 @@ public: */ void setCompleteTimeout(size_t timeout_ms); + /** + * 设置请求头中的 keep-alive 语义,默认启用 + * Set whether requests should advertise keep-alive semantics, enabled by default + + * [AUTO-TRANSLATED:6f62f63c] + */ + void setRequestKeepAlive(bool enable); + /** * 设置http代理url * Set http proxy url @@ -276,7 +284,6 @@ private: void onResponseCompleted_l(const toolkit::SockException &ex); void onConnect_l(const toolkit::SockException &ex); void checkCookie(HttpHeader &headers); - private: //for http response bool _complete = false; @@ -289,7 +296,7 @@ private: std::shared_ptr _chunked_splitter; //for request args - bool _is_https; + bool _is_https = false; std::string _url; HttpHeader _user_set_header; HttpBody::Ptr _body; @@ -308,9 +315,10 @@ private: toolkit::Ticker _wait_body; toolkit::Ticker _wait_complete; + bool _request_keep_alive = true; bool _used_proxy = false; bool _proxy_connected = false; - uint16_t _proxy_port; + uint16_t _proxy_port = 0; std::string _proxy_url; std::string _proxy_host; std::string _proxy_auth; diff --git a/src/Http/HttpCookieManager.cpp b/src/Http/HttpCookieManager.cpp index 1b3d4c35..7fbae80b 100644 --- a/src/Http/HttpCookieManager.cpp +++ b/src/Http/HttpCookieManager.cpp @@ -61,6 +61,11 @@ bool HttpServerCookie::isExpired() { return _ticker.elapsedTime() > _max_elapsed * 1000; } +void HttpServerCookie::setExpired() { + _ticker.resetTime(); + _max_elapsed = 0; +} + void HttpServerCookie::setAttach(toolkit::Any attach) { _attach = std::move(attach); } diff --git a/src/Http/HttpCookieManager.h b/src/Http/HttpCookieManager.h index a0e29a4e..64868302 100644 --- a/src/Http/HttpCookieManager.h +++ b/src/Http/HttpCookieManager.h @@ -118,6 +118,11 @@ public: */ bool isExpired(); + /** + * 使cookie过期作废 + */ + void setExpired(); + /** * 设置附加数据 * Set additional data @@ -128,7 +133,6 @@ public: /* * 获取附加数据 - /* * Get additional data * [AUTO-TRANSLATED:e277d75d] diff --git a/src/Http/HttpFileManager.cpp b/src/Http/HttpFileManager.cpp index c9701d43..a659cadf 100644 --- a/src/Http/HttpFileManager.cpp +++ b/src/Http/HttpFileManager.cpp @@ -17,6 +17,7 @@ #include "HttpConst.h" #include "HttpSession.h" #include "HttpFileManager.h" +#include "Common/MultiMediaSourceMuxer.h" using namespace std; using namespace toolkit; @@ -31,7 +32,7 @@ namespace mediakit { // If the player does not access the cookie within 60 seconds, the hls playback authentication will be triggered again. static size_t kHlsCookieSecond = 60; static size_t kFindSrcIntervalSecond = 3; -static const string kCookieName = "ZL_COOKIE"; +static const string kCookieName = "ZLM_HTTP_COOKIE"; static const string kHlsSuffix = "/hls.m3u8"; static const string kHlsFMP4Suffix = "/hls.fmp4.m3u8"; @@ -48,6 +49,8 @@ struct HttpCookieAttachment { // 上次鉴权失败信息,为空则上次鉴权成功 [AUTO-TRANSLATED:de48b753] // Last authentication failure information, empty means last authentication succeeded string _err_msg; + // hls文件根目录 + string _hls_root_path; // hls直播时的其他一些信息,主要用于播放器个数计数以及流量计数 [AUTO-TRANSLATED:790de53a] // Other information during hls live broadcast, mainly used for player count and traffic count HlsCookieData::Ptr _hls_data; @@ -315,36 +318,9 @@ static bool emitHlsPlayed(const Parser &parser, const MediaInfo &media_info, con return flag; } -class SockInfoImp : public SockInfo{ -public: - using Ptr = std::shared_ptr; - - string get_local_ip() override { - return _local_ip; - } - - uint16_t get_local_port() override { - return _local_port; - } - - string get_peer_ip() override { - return _peer_ip; - } - - uint16_t get_peer_port() override { - return _peer_port; - } - - string getIdentifier() const override { - return _identifier; - } - - string _local_ip; - string _peer_ip; - string _identifier; - uint16_t _local_port; - uint16_t _peer_port; -}; +static std::string getUidFromParams(const string ¶ms) { + return params; +} /** * 判断http客户端是否有权限访问文件的逻辑步骤 @@ -362,11 +338,11 @@ public: * [AUTO-TRANSLATED:dfc0f15f] */ -static void canAccessPath(Session &sender, const Parser &parser, const MediaInfo &media_info, bool is_dir, +static void canAccessPath(Session &sender, const Parser &parser, const MediaInfo &media_info, const std::string &file_path, bool is_dir, const function &callback) { // 获取用户唯一id [AUTO-TRANSLATED:5b1cf4bf] // Get the user's unique id - auto uid = parser.params(); + auto uid = getUidFromParams(parser.params()); auto path = parser.url(); // 先根据http头中的cookie字段获取cookie [AUTO-TRANSLATED:155cf682] @@ -401,7 +377,7 @@ static void canAccessPath(Session &sender, const Parser &parser, const MediaInfo } // 上次鉴权失败,但是如果url参数发生变更,那么也重新鉴权下 [AUTO-TRANSLATED:df9bd345] // Last authentication failed, but if the url parameter changes, then re-authenticate - if (parser.params().empty() || parser.params() == cookie->getUid()) { + if (parser.params().empty() || getUidFromParams(parser.params()) == cookie->getUid()) { // url参数未变,或者本来就没有url参数,那么判断本次请求为重复请求,无访问权限 [AUTO-TRANSLATED:f46b4fca] // The url parameter has not changed, or there is no url parameter at all, then determine that the current request is a duplicate request and has no access permission callback(attach._err_msg, update_cookie ? cookie : nullptr); @@ -415,17 +391,18 @@ static void canAccessPath(Session &sender, const Parser &parser, const MediaInfo bool is_hls = media_info.schema == HLS_SCHEMA || media_info.schema == HLS_FMP4_SCHEMA; - SockInfoImp::Ptr info = std::make_shared(); - info->_identifier = sender.getIdentifier(); - info->_peer_ip = sender.get_peer_ip(); - info->_peer_port = sender.get_peer_port(); - info->_local_ip = sender.get_local_ip(); - info->_local_port = sender.get_local_port(); + weak_ptr weak_session = static_pointer_cast(sender.shared_from_this()); // 该用户从来未获取过cookie,这个时候我们广播是否允许该用户访问该http目录 [AUTO-TRANSLATED:8f4b3dd2] // This user has never obtained a cookie, at this time we broadcast whether to allow this user to access this http directory - HttpSession::HttpAccessPathInvoker accessPathInvoker = [callback, uid, path, is_dir, is_hls, media_info, info] + HttpSession::HttpAccessPathInvoker accessPathInvoker = [callback, uid, path, is_dir, is_hls, media_info, weak_session] (const string &err_msg, const string &cookie_path_in, int life_second) { + auto strong_session = weak_session.lock(); + if (!strong_session) { + // http客户端已经断开,不需要回复 [AUTO-TRANSLATED:9a252e21] + // The http client has disconnected and does not need to reply + return; + } HttpServerCookie::Ptr cookie; if (life_second) { // 本次鉴权设置了有效期,我们把鉴权结果缓存在cookie中 [AUTO-TRANSLATED:5a12f48e] @@ -447,11 +424,11 @@ static void canAccessPath(Session &sender, const Parser &parser, const MediaInfo if (is_hls) { // hls相关信息 [AUTO-TRANSLATED:37893a71] // hls related information - attach->_hls_data = std::make_shared(media_info, info); + attach->_hls_data = std::make_shared(media_info, strong_session); } - toolkit::Any any; - any.set(std::move(attach)); - callback(err_msg, HttpCookieManager::Instance().addCookie(kCookieName, uid, life_second, std::move(any))); + toolkit::Any any; + any.set(std::move(attach)); + callback(err_msg, HttpCookieManager::Instance().addCookie(kCookieName, uid, life_second, std::move(any))); } else { callback(err_msg, nullptr); } @@ -466,7 +443,7 @@ static void canAccessPath(Session &sender, const Parser &parser, const MediaInfo // 事件未被拦截,则认为是http下载请求 [AUTO-TRANSLATED:7d449ccc] // The event was not intercepted, it is considered an http download request - bool flag = NOTICE_EMIT(BroadcastHttpAccessArgs, Broadcast::kBroadcastHttpAccess, parser, path, is_dir, accessPathInvoker, sender); + bool flag = NOTICE_EMIT(BroadcastHttpAccessArgs, Broadcast::kBroadcastHttpAccess, parser, path, file_path, is_dir, accessPathInvoker, sender); if (!flag) { // 此事件无人监听,我们默认都有权限访问 [AUTO-TRANSLATED:e1524c0f] // No one is listening to this event, we assume that everyone has permission to access it by default @@ -498,6 +475,8 @@ static string pathCat(const string &a, const string &b){ return a + '/' + b; } +static string getFilePath(const Parser &parser,const MediaInfo &media_info, Session *sender, const string &customRootPath = ""); + /** * 访问文件 * @param sender 事件触发者 @@ -511,17 +490,11 @@ static string pathCat(const string &a, const string &b){ * @param media_info http url information * @param file_path Absolute file path * @param cb Callback object - + * [AUTO-TRANSLATED:2d840fe6] */ static void accessFile(Session &sender, const Parser &parser, const MediaInfo &media_info, const string &file_path, const HttpFileManager::invoker &cb) { bool is_hls = end_with(file_path, kHlsSuffix) || end_with(file_path, kHlsFMP4Suffix); - if (!is_hls && !File::fileExist(file_path)) { - // 文件不存在且不是hls,那么直接返回404 [AUTO-TRANSLATED:7aae578b] - // The file does not exist and is not hls, so directly return 404 - sendNotFound(cb); - return; - } if (is_hls) { // hls,那么移除掉后缀获取真实的stream_id并且修改协议为HLS [AUTO-TRANSLATED:94b5818a] // hls, then remove the suffix to get the real stream_id and change the protocol to HLS @@ -537,7 +510,7 @@ static void accessFile(Session &sender, const Parser &parser, const MediaInfo &m weak_ptr weakSession = static_pointer_cast(sender.shared_from_this()); // 判断是否有权限访问该文件 [AUTO-TRANSLATED:b7f595f5] // Determine whether you have permission to access this file - canAccessPath(sender, parser, media_info, false, [cb, file_path, parser, is_hls, media_info, weakSession](const string &err_msg, const HttpServerCookie::Ptr &cookie) { + canAccessPath(sender, parser, media_info, file_path, false, [cb, file_path, parser, is_hls, media_info, weakSession](const string &err_msg, const HttpServerCookie::Ptr &cookie) { auto strongSession = weakSession.lock(); if (!strongSession) { // http客户端已经断开,不需要回复 [AUTO-TRANSLATED:9a252e21] @@ -582,6 +555,13 @@ static void accessFile(Session &sender, const Parser &parser, const MediaInfo &m invoker.responseFile(parser.getHeader(), httpHeader, file_content.empty() ? file_path : file_content, !is_hls && !is_forbid_cache, file_content.empty()); }; + if (cookie) { + auto &attach = cookie->getAttach(); + if (!attach._hls_root_path.empty()) { + // 强制替换为真实hls路径 + const_cast(file_path) = getFilePath(parser, media_info, nullptr, attach._hls_root_path); + } + } if (!is_hls || !cookie) { // 不是hls或访问m3u8文件不带cookie, 直接回复文件或404 [AUTO-TRANSLATED:64e5d19b] // Not hls or accessing m3u8 files without cookies, directly reply to the file or 404 @@ -631,6 +611,10 @@ static void accessFile(Session &sender, const Parser &parser, const MediaInfo &m // Reset the MediaSource search timer attach._find_src_ticker.resetTime(); + auto muxer = hls->getMuxer(); + if (muxer) { + attach._hls_root_path = muxer->getOption().hls_save_path; + } // m3u8文件可能不存在, 等待m3u8索引文件按需生成 [AUTO-TRANSLATED:0dbd4df2] // The m3u8 file may not exist, wait for the m3u8 index file to be generated on demand hls->getIndexFile([response_file, file_path, cookie, cb, parser](const string &file) { @@ -640,13 +624,15 @@ static void accessFile(Session &sender, const Parser &parser, const MediaInfo &m }); } -static string getFilePath(const Parser &parser,const MediaInfo &media_info, Session &sender) { +static string getFilePath(const Parser &parser,const MediaInfo &media_info, Session *sender, const string &customRootPath) { GET_CONFIG(bool, enableVhost, General::kEnableVhost); - GET_CONFIG(string, rootPath, Http::kRootPath); + GET_CONFIG(string, httpRootPath, Http::kRootPath); GET_CONFIG_FUNC(StrCaseMap, virtualPathMap, Http::kVirtualPath, [](const string &str) { return Parser::parseArgs(str, ";", ","); }); + auto rootPath = customRootPath.empty() ? httpRootPath : customRootPath; + string url, path, virtual_app; auto it = virtualPathMap.find(media_info.app); if (it != virtualPathMap.end()) { @@ -675,10 +661,12 @@ static string getFilePath(const Parser &parser,const MediaInfo &media_info, Sess // The accessed http file must not be outside the http root directory throw std::runtime_error("Attempting to access files outside of the http root directory"); } - // 替换url,防止返回的目录索引网页被注入非法内容 [AUTO-TRANSLATED:463ad1b1] - // Replace the url to prevent the returned directory index page from being injected with illegal content - const_cast(parser).setUrl("/" + virtual_app + ret.substr(http_root.size())); - NOTICE_EMIT(BroadcastHttpBeforeAccessArgs, Broadcast::kBroadcastHttpBeforeAccess, parser, ret, sender); + if (sender) { + // 替换url,防止返回的目录索引网页被注入非法内容 [AUTO-TRANSLATED:463ad1b1] + // Replace the url to prevent the returned directory index page from being injected with illegal content + const_cast(parser).setUrl("/" + virtual_app + ret.substr(http_root.size())); + NOTICE_EMIT(BroadcastHttpBeforeAccessArgs, Broadcast::kBroadcastHttpBeforeAccess, parser, ret, *sender); + } return ret; } @@ -691,14 +679,14 @@ static string getFilePath(const Parser &parser,const MediaInfo &media_info, Sess * @param sender Event trigger * @param parser http request * @param cb Callback object - + * [AUTO-TRANSLATED:a79c824d] */ void HttpFileManager::onAccessPath(Session &sender, Parser &parser, const HttpFileManager::invoker &cb) { auto fullUrl = "http://" + parser["Host"] + parser.fullUrl(); MediaInfo media_info(fullUrl); - auto file_path = getFilePath(parser, media_info, sender); - if (file_path.size() == 0) { + auto file_path = getFilePath(parser, media_info, &sender); + if (file_path.empty()) { sendNotFound(cb); return; } @@ -729,7 +717,7 @@ void HttpFileManager::onAccessPath(Session &sender, Parser &parser, const HttpFi } // 判断是否有权限访问该目录 [AUTO-TRANSLATED:963d02a6] // Determine if there is permission to access this directory - canAccessPath(sender, parser, media_info, true, [strMenu, cb](const string &err_msg, const HttpServerCookie::Ptr &cookie) mutable{ + canAccessPath(sender, parser, media_info, file_path, true, [strMenu, cb](const string &err_msg, const HttpServerCookie::Ptr &cookie) mutable{ if (!err_msg.empty()) { strMenu = err_msg; } diff --git a/src/Http/HttpSession.cpp b/src/Http/HttpSession.cpp index 0b89d182..cc4dc0e9 100644 --- a/src/Http/HttpSession.cpp +++ b/src/Http/HttpSession.cpp @@ -61,6 +61,7 @@ ssize_t HttpSession::onRecvHeader(const char *header, size_t len) { static onceToken token([]() { s_func_map.emplace("GET", &HttpSession::onHttpRequest_GET); s_func_map.emplace("POST", &HttpSession::onHttpRequest_POST); + s_func_map.emplace("PUT", &HttpSession::onHttpRequest_POST); // DELETE命令用于whip/whep用,只用于触发http api [AUTO-TRANSLATED:f3b7aaea] // DELETE command is used for whip/whep, only used to trigger http api s_func_map.emplace("DELETE", &HttpSession::onHttpRequest_POST); @@ -213,6 +214,7 @@ bool HttpSession::checkWebSocket() { if (Sec_WebSocket_Key.empty()) { return false; } + _is_websocket = true; auto Sec_WebSocket_Accept = encodeBase64(SHA1::encode_bin(Sec_WebSocket_Key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); KeyValue headerOut; @@ -223,22 +225,22 @@ bool HttpSession::checkWebSocket() { headerOut["Sec-WebSocket-Protocol"] = _parser["Sec-WebSocket-Protocol"]; } - auto res_cb = [this, headerOut]() { - _live_over_websocket = true; - sendResponse(101, false, nullptr, headerOut, nullptr, true); + auto res_cb = []() { + // 改成先回复http头模式,以解决按需播放场景下websocket请求pending问题:#4553 }; - auto res_cb_flv = [this, headerOut]() mutable { - _live_over_websocket = true; + auto res_immediately = [this, headerOut]() mutable { headerOut.emplace("Cache-Control", "no-store"); sendResponse(101, false, nullptr, headerOut, nullptr, true); + _live_over_websocket = true; }; // 判断是否为websocket-flv [AUTO-TRANSLATED:31682d7a] // Determine whether it is websocket-flv - if (checkLiveStreamFlv(res_cb_flv)) { + if (checkLiveStreamFlv(res_cb)) { // 这里是websocket-flv直播请求 [AUTO-TRANSLATED:4bea5956] // This is a websocket-flv live request + res_immediately(); return true; } @@ -247,6 +249,7 @@ bool HttpSession::checkWebSocket() { if (checkLiveStreamTS(res_cb)) { // 这里是websocket-ts直播请求 [AUTO-TRANSLATED:8ab08dd6] // This is a websocket-ts live request + res_immediately(); return true; } @@ -255,6 +258,7 @@ bool HttpSession::checkWebSocket() { if (checkLiveStreamFMP4(res_cb)) { // 这里是websocket-fmp4直播请求 [AUTO-TRANSLATED:ccf0c1e2] // This is a websocket-fmp4 live request + res_immediately(); return true; } @@ -305,6 +309,12 @@ bool HttpSession::checkLiveStream(const string &schema, const string &url_suffix return false; } + if (_is_websocket) { + _media_info.protocol = overSsl() ? "wss" : "ws"; + } else { + _media_info.protocol = overSsl() ? "https" : "http"; + } + bool close_flag = !strcasecmp(_parser["Connection"].data(), "close"); weak_ptr weak_self = static_pointer_cast(shared_from_this()); @@ -349,7 +359,7 @@ bool HttpSession::checkLiveStream(const string &schema, const string &url_suffix Broadcast::AuthInvoker invoker = [weak_self, onRes](const string &err) { if (auto strong_self = weak_self.lock()) { - strong_self->async([onRes, err]() { onRes(err); }); + strong_self->async([onRes, err]() { onRes(err); }, false); } }; @@ -357,7 +367,7 @@ bool HttpSession::checkLiveStream(const string &schema, const string &url_suffix if (!flag) { // 该事件无人监听,默认不鉴权 [AUTO-TRANSLATED:e1fbc6ae] // No one is listening to this event, no authentication by default - onRes(""); + invoker(""); } return true; } @@ -387,7 +397,7 @@ bool HttpSession::checkLiveStreamFMP4(const function &cb) { _fmp4_reader = fmp4_src->getRing()->attach(getPoller()); _fmp4_reader->setGetInfoCB([weak_self]() { Any ret; - ret.set(static_pointer_cast(weak_self.lock())); + ret.set(static_pointer_cast(weak_self.lock())); return ret; }); _fmp4_reader->setDetachCB([weak_self]() { @@ -437,7 +447,7 @@ bool HttpSession::checkLiveStreamTS(const function &cb) { _ts_reader = ts_src->getRing()->attach(getPoller()); _ts_reader->setGetInfoCB([weak_self]() { Any ret; - ret.set(static_pointer_cast(weak_self.lock())); + ret.set(static_pointer_cast(weak_self.lock())); return ret; }); _ts_reader->setDetachCB([weak_self]() { @@ -652,6 +662,23 @@ void HttpSession::sendResponse(int code, const HttpSession::KeyValue &header, const HttpBody::Ptr &body, bool no_content_length) { + if (_live_over_websocket) { + WebSocketHeader ws_header; + ws_header._fin = true; + ws_header._reserved = 0; + ws_header._opcode = WebSocketHeader::CLOSE; + ws_header._mask_flag = false; + uint16_t why = htons(0xFFFF & code); + std::string buffer; + buffer.append(reinterpret_cast(&why), 2); + if (body && code != 404) { + buffer.append(body->readData(body->remainSize())->toString()); + } else { + buffer.append("unknown reason"); + } + WebSocketSplitter::encode(ws_header, std::make_shared(std::move(buffer))); + return; + } GET_CONFIG(string, charSet, Http::kCharSet); GET_CONFIG(uint32_t, keepAliveSec, Http::kKeepAliveSecond); diff --git a/src/Http/HttpSession.h b/src/Http/HttpSession.h index 27544318..b714e4b6 100644 --- a/src/Http/HttpSession.h +++ b/src/Http/HttpSession.h @@ -158,6 +158,7 @@ protected: private: bool _is_live_stream = false; bool _live_over_websocket = false; + bool _is_websocket = false; // 超时时间 [AUTO-TRANSLATED:f15e2672] // Timeout size_t _keep_alive_sec = 0; diff --git a/src/Http/TsPlayer.cpp b/src/Http/TsPlayer.cpp index 339064ec..f8fdb0ab 100644 --- a/src/Http/TsPlayer.cpp +++ b/src/Http/TsPlayer.cpp @@ -25,6 +25,7 @@ void TsPlayer::play(const string &url) { setHeaderTimeout((*this)[Client::kTimeoutMS].as()); setBodyTimeout((*this)[Client::kMediaTimeoutMS].as()); setMethod("GET"); + addCustomHeader(this); sendRequest(url); } @@ -57,4 +58,11 @@ void TsPlayer::onResponseBody(const char *buf, size_t size) { } } +size_t TsPlayer::getRecvSpeed() { + return TcpClient::getRecvSpeed(); +} + +size_t TsPlayer::getRecvTotalBytes() { + return TcpClient::getRecvTotalBytes(); +} } // namespace mediakit \ No newline at end of file diff --git a/src/Http/TsPlayer.h b/src/Http/TsPlayer.h index 4674f688..270cc22d 100644 --- a/src/Http/TsPlayer.h +++ b/src/Http/TsPlayer.h @@ -28,6 +28,9 @@ public: */ void play(const std::string &url) override; + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; + /** * 停止播放 * Stop playing diff --git a/src/Http/TsPlayerImp.h b/src/Http/TsPlayerImp.h index 821fffb3..ff02ccbb 100644 --- a/src/Http/TsPlayerImp.h +++ b/src/Http/TsPlayerImp.h @@ -21,6 +21,8 @@ public: using Ptr = std::shared_ptr; TsPlayerImp(const toolkit::EventPoller::Ptr &poller = nullptr); + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; private: //// TsPlayer override//// diff --git a/src/Http/TsplayerImp.cpp b/src/Http/TsplayerImp.cpp index 42df6a5a..0b99d115 100644 --- a/src/Http/TsplayerImp.cpp +++ b/src/Http/TsplayerImp.cpp @@ -46,6 +46,10 @@ void TsPlayerImp::onPlayResult(const SockException &ex) { } void TsPlayerImp::onShutdown(const SockException &ex) { + if (!ex) { + // http-ts拉流,如果为eof正常断开,那么强制为异常状态 + const_cast(ex).reset(Err_other, ex.what()); + } while (_demuxer) { try { // shared_from_this()可能抛异常 [AUTO-TRANSLATED:6af9bd3c] @@ -77,4 +81,11 @@ vector TsPlayerImp::getTracks(bool ready) const { return static_pointer_cast(_demuxer)->getTracks(ready); } +size_t TsPlayerImp::getRecvSpeed() { + return TcpClient::getRecvSpeed(); +} + +size_t TsPlayerImp::getRecvTotalBytes() { + return TcpClient::getRecvTotalBytes(); +} }//namespace mediakit \ No newline at end of file diff --git a/src/Onvif/Onvif.cpp b/src/Onvif/Onvif.cpp new file mode 100644 index 00000000..35f49d6a --- /dev/null +++ b/src/Onvif/Onvif.cpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "Onvif.h" +#include "Util/util.h" +#include "Util/onceToken.h" +#include "pugixml.hpp" +#include "SoapUtil.h" +#include "Common/config.h" +#include "Common/MediaSource.h" + +using namespace std; +using namespace toolkit; +using namespace mediakit; + +////////////Rtp代理相关配置/////////// +namespace Onvif { +#define ONVIF_FIELD "onvif." +const string kPort = ONVIF_FIELD"port"; +static onceToken token([]() { + mINI::Instance()[kPort] = 3702; +}); + +} //namespace RtpProxy + +bool OnvifSearcher::onDeviceCB::operator()(std::map &device_info, const std::string &onvif_url) { + if (expired()) { + //超时 + cb = nullptr; + return false; + } + if (!cb) { + return false; + } + return cb(device_info, onvif_url); +} + +bool OnvifSearcher::onDeviceCB::expired() const { + return ticker.elapsedTime() > timeout_ms; +} + +///////////////////////////////////////////////////////////////////////////////////// + +INSTANCE_IMP(OnvifSearcher) + +OnvifSearcher::OnvifSearcher() { + _poller = EventPollerPool::Instance().getPoller(); +} + +void OnvifSearcher::sendSearchBroadcast(std::string subnet_prefix, onDevice cb, uint64_t timeout_ms) { + weak_ptr weak_self = shared_from_this(); + _poller->async([weak_self, cb, timeout_ms, subnet_prefix]() mutable { + auto strong_self = weak_self.lock(); + if (strong_self) { + strong_self->sendSearchBroadcast_l(move(subnet_prefix), std::move(cb), timeout_ms); + } + }); +} + +void OnvifSearcher::sendSearchBroadcast_l(const std::string &subnet_prefix, onDevice cb, uint64_t timeout_ms) { + static struct sockaddr_in s_search_address; + static onceToken s_token([]() { + s_search_address.sin_family = AF_INET; + s_search_address.sin_port = htons(3702); + s_search_address.sin_addr.s_addr = inet_addr("239.255.255.250"); + bzero(&(s_search_address.sin_zero), sizeof s_search_address.sin_zero); + }); + + GET_CONFIG(uint16_t, onvif_port, Onvif::kPort); + if (_sock_list.empty()) { + for (auto &network : SockUtil::getInterfaceList()) { + auto sock = Socket::createSocket(_poller, false); + sock->bindUdpSock(onvif_port, network["ip"]); + SockUtil::setBroadcast(sock->rawFD()); + weak_ptr weak_self = shared_from_this(); + sock->setOnRead([weak_self](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) { + auto strong_self = weak_self.lock(); + if (strong_self) { + strong_self->onDeviceResponse(buf, addr, addr_len); + } + }); + _sock_list.emplace_back(std::move(sock)); + } + } + + if (!_timer) { + weak_ptr weak_self = shared_from_this(); + _timer = std::make_shared(1, [weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return false; + } + for (auto it = strong_self->_cb_map.begin(); it != strong_self->_cb_map.end();) { + if (it->second.expired()) { + it = strong_self->_cb_map.erase(it); + continue; + } + ++it; + } + return true; + }, _poller); + } + + auto uuid = SoapUtil::createUuidString(); + auto xml = SoapUtil::createDiscoveryString(uuid); + auto &ref = _cb_map[uuid]; + ref.cb = std::move(cb); + ref.timeout_ms = timeout_ms; + std::string ip; + struct sockaddr_in target {}; + + for (auto &sock : _sock_list) { + sock->send(xml, (struct sockaddr *)&s_search_address, sizeof(s_search_address)); + if (!subnet_prefix.empty()) { + for (int i = 1; i <= 254; ++i) { + ip = subnet_prefix + "." + std::to_string(i); + target.sin_family = AF_INET; + target.sin_port = htons(3702); + inet_pton(AF_INET, ip.c_str(), &target.sin_addr); + sock->send(xml, (struct sockaddr *)&target, sizeof(target)); + } + } + } +} + +std::map getDeviceInfo(const string &scopes) { + std::map keys = {{"onvif://www.onvif.org/location", "location"}, + {"onvif://www.onvif.org/name", "name"}, + {"onvif://www.onvif.org/hardware", "hardware"}}; + std::map ret; + auto vec = split(scopes, " "); + for (auto &item : vec) { + string key; + for (auto &pr : keys) { + if (start_with(item, pr.first)) { + key = pr.second; + break; + } + } + if (!key.empty()) { + auto pos = item.rfind('/'); + ret.emplace(key, item.substr(pos + 1)); + } + } + return ret; +} + +void OnvifSearcher::onDeviceResponse(const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) { + try { + SoapObject root; + root.load(buf->data(), buf->size()); + auto uuid = root["Envelope/Header/RelatesTo"]; + auto device_service = root["Envelope/Body/ProbeMatches/ProbeMatch/XAddrs"]; + auto scopes = root["Envelope/Body/ProbeMatches/ProbeMatch/Scopes"]; + auto map = getDeviceInfo(scopes.as_xml().text().as_string()); + onGotDevice(uuid.as_xml().text().as_string(), map, device_service.as_xml().text().as_string()); + } catch (std::exception &ex) { + WarnL << ex.what(); + } +} + +static string getIpv4Url(const std::string &onvif_url) { + auto urls = split(onvif_url, " "); + if (urls.size() > 1) { + for (auto url : urls) { + MediaInfo info(url); + if (isIP(info.host.data())) { + return url; + } + } + } + return onvif_url; +} + +void OnvifSearcher::onGotDevice(const std::string &uuid, std::map &device_info, + const std::string &onvif_url) { + auto it = _cb_map.find(uuid); + if (it == _cb_map.end()) { + WarnL << uuid << " " << onvif_url << " " << device_info["location"] << " " << device_info["name"] << " " + << device_info["hardware"]; + return; + } + auto flag = it->second(device_info, getIpv4Url(onvif_url)); + if (!flag) { + _cb_map.erase(it); + } +} \ No newline at end of file diff --git a/src/Onvif/Onvif.h b/src/Onvif/Onvif.h new file mode 100644 index 00000000..4e77a694 --- /dev/null +++ b/src/Onvif/Onvif.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_ONVIF_H +#define ZLMEDIAKIT_ONVIF_H + +#include +#include +#include "Network/Socket.h" +#include "Network/Buffer.h" + +class OnvifSearcher : public std::enable_shared_from_this { +public: + //返回false代表不再监听事件 + using onDevice = std::function &device_info, const std::string &onvif_url)>; + OnvifSearcher(); + + static OnvifSearcher &Instance(); + void sendSearchBroadcast(std::string subnet_prefix, onDevice cb = nullptr, uint64_t timeout_ms = 10 * 1000); + +private: + void onDeviceResponse(const toolkit::Buffer::Ptr &buf, struct sockaddr *addr, int addr_len); + void onGotDevice(const std::string &uuid, std::map &device_info, const std::string &onvif_url); + void sendSearchBroadcast_l(const std::string &subnet_prefix, onDevice cb, uint64_t timeout_ms); + +private: + struct onDeviceCB{ + onDevice cb; + toolkit::Ticker ticker; + uint64_t timeout_ms; + + bool expired() const; + bool operator()(std::map &device_info, const std::string &onvif_url); + }; + +private: + toolkit::EventPoller::Ptr _poller; + toolkit::Timer::Ptr _timer; + std::vector _sock_list; + std::unordered_map _cb_map; +}; + +#endif //ZLMEDIAKIT_ONVIF_H diff --git a/src/Onvif/SoapUtil.cpp b/src/Onvif/SoapUtil.cpp new file mode 100644 index 00000000..cbc123e8 --- /dev/null +++ b/src/Onvif/SoapUtil.cpp @@ -0,0 +1,491 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include +#include +#include "SoapUtil.h" +#include "Util/util.h" +#include "Util/SHA1.h" +#include "Util/logger.h" +#include "Util/base64.h" +#include "Util/onceToken.h" +#include "Http/HttpRequester.h" +#include "Rtsp/Rtsp.h" + +using namespace std; +using namespace toolkit; +using namespace mediakit; + +static pugi::xml_node find_node(const pugi::xml_node &parent, const std::string &end_str) { + auto ret = parent.find_child([&](const pugi::xml_node &node) { + auto len = strlen(node.name()); + if (len < end_str.size()) { + return false; + } + if (end_str == node.name()) { + return true; + } + if (*(node.name() + len - end_str.size() - 1) != ':') { + return false; + } + return strcasecmp(node.name() + len - end_str.size(), end_str.data()) == 0; + }); + return ret; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +SoapObject::SoapObject() { + _root = std::make_shared(); +} + +SoapObject::SoapObject(std::shared_ptr node) { + _root = std::move(node); +} + +void SoapObject::load(const char *data, size_t len) { + auto doc = std::make_shared(); + auto result = doc->load_string(data, len); + if (!result) { + throw std::invalid_argument(string("解析xml失败:") + result.description()); + } + _root = std::move(doc); +} + +SoapObject::operator bool() const { + return !_root->empty(); +} + +SoapObject SoapObject::operator[](const string &path) const{ + auto hit = *_root; + auto node_name = split(path, "/"); + for (auto &node : node_name) { + hit = find_node(hit, node); + if (hit.empty()) { + return SoapObject(); + } + } + auto ref = _root; + shared_ptr node(new pugi::xml_node(std::move(hit)), [ref](pugi::xml_node *ptr) { + delete ptr; + }); + return SoapObject(std::move(node)); +} + +SoapObject SoapObject::operator[](size_t index) const { + for (auto &hit : *_root) { + if (index-- == 0) { + auto ref = _root; + shared_ptr node(new pugi::xml_node(hit), [ref](pugi::xml_node *ptr) { + delete ptr; + }); + return SoapObject(node); + } + } + return SoapObject(); +} + +std::string SoapObject::as_string() const { + if (!(bool) (*this)) { + return ""; + } + xml_string_writer writer; + _root->print(writer); + return writer.result; +} + +pugi::xml_node SoapObject::as_xml() const { + return *_root; +} + +SoapObject::SoapObject(const pugi::xml_node &node, const SoapObject &ref) { + auto root = ref._root; + _root.reset(new pugi::xml_node(node.internal_object()), [root](pugi::xml_node *ptr) { + delete ptr; + }); +} + +std::string SoapUtil::createDiscoveryString(const string &uuid_in) { + auto uuid = uuid_in; + if (uuid.empty()) { + uuid = createUuidString(); + } + static constexpr char str_fmt[] = "\n" + "\n" + " \n" + " %s\n" + " urn:schemas-xmlsoap-org:ws:2005:04:discovery\n" + " http://schemas.xmlsoap.org/ws/2005/04/discovery/Probe\n" + " \n" + " \n" + " \n" + " dn:NetworkVideoTransmitter\n" + " \n" + " \n" + ""; + return print_to_string(str_fmt, uuid.data()); +} + +static std::string creatString4() { + std::mt19937 rng(std::random_device{}()); + string ret = StrPrinter << std::hex << std::uppercase << std::setfill('0') << ((1 + rng()) & 0xFFFF); + return ret; +} + +std::string SoapUtil::createUuidString() { + auto ret = std::string("uuid:") + + creatString4() + creatString4() + '-' + + creatString4() + '-' + + creatString4() + '-' + + creatString4() + '-' + + creatString4() + creatString4() + creatString4(); + TraceL << ret; + return ret; +} + +static tuple +makePasswordDigest(const string &user_name, const string &passwd) { + std::mt19937 rng(std::random_device{}()); + string nonce; + nonce.resize(16); + for (auto &ch : nonce) { + ch = rng() & 0xFF; + } + auto timestamp = getTimeStr("%Y-%m-%dT%H:%M:%S%z"); + auto passdigest = SHA1::encode_bin(nonce + timestamp + passwd); + return std::make_tuple(encodeBase64(passdigest), encodeBase64(nonce), timestamp); +} + +static std::string +createSoapSecurity(const string &user_name, const string &passdigest, const string &nonce, const string ×tamp) { + static constexpr char str_fmt[] = + R"( + + %s + %s + %s + %s + )"; + return print_to_string(str_fmt, user_name.data(), passdigest.data(), nonce.data(), timestamp.data()); +} + +std::string SoapUtil::createSoapRequest(const string &body, const string &user_name, const string &passwd) { + string header = R"()"; + if (!user_name.empty() && !passwd.empty()) { + auto req = makePasswordDigest(user_name, passwd); + header += createSoapSecurity(user_name, std::get<0>(req), std::get<1>(req), std::get<2>(req)); + } + header += R"()"; + header += body; + header += R"()"; + return header; +} + +SoapErr::SoapErr(std::string url, + std::string action, + SockException ex, + const mediakit::Parser &parser, + std::string err) { + _url = std::move(url); + _action = std::move(action); + _net_err = std::move(ex); + _http_code = atoi(parser.status().data()); + _http_msg = parser.statusStr(); + _other_err = std::move(err); +} + +SoapErr::operator std::string() const { + _StrPrinter printer; + printer << "request onvif service failed, url:" << _url << ", action:" << _action << ", "; + if (_net_err) { + return printer << "network err:" << _net_err.what() << endl; + } + if (_http_code != 200) { + return printer << "http bad status:" << _http_code << " " << _http_msg << endl; + } + if (!_other_err.empty()) { + return printer << _other_err << endl; + } + return ""; +} + +SoapErr::operator bool() const { + return _net_err || _http_code != 200 || !_other_err.empty(); +} + +bool SoapErr::empty() const { + return !*this; +} + +int SoapErr::httpCode() const { + return _http_code; +} + +std::ostream& operator<<(std::ostream& sout, const SoapErr &err) { + sout << (string)err; + return sout; +} + +void SoapUtil::sendSoapRequest(const string &url, const string &SOAPAction, const string &body, const SoapRequestCB &func, + float timeout_sec) { + HttpRequester::Ptr requester(new HttpRequester); + requester->setMethod("POST"); + requester->setBody(body); + requester->addHeader("Content-Type", "text/xml; charset=utf-8; action=" + SOAPAction); + requester->addHeader("Accept", "text/xml; charset=utf-8"); + requester->addHeader("SOAPAction", SOAPAction); + std::shared_ptr ticker(new Ticker); + requester->startRequester(url, [url, SOAPAction, func, requester, ticker](const SockException &ex, + const Parser &parser) mutable { + + onceToken token(nullptr, [&]() mutable { + requester.reset(); + }); + auto invoker = [&](const SoapObject &node, const SoapErr &err) { + if (err) { + WarnL << err; + } + if (func) { + func(node, err); + } + }; + + if (ex) { + invoker(SoapObject(), SoapErr(url, SOAPAction, ex, parser)); + return; + } + if (parser.status() != "200") { + invoker(SoapObject(), SoapErr(url, SOAPAction, ex, parser)); + return; + } + SoapObject root; + try { + root.load(parser.content().data(), parser.content().size()); + } catch (std::exception &e) { + auto err = StrPrinter << "[parse xml failed]:" << e.what() << endl; + invoker(SoapObject(), SoapErr(url, SOAPAction, ex, parser, err)); + return; + } + auto body = root["Envelope/Body"]; + if (!body) { + auto err = StrPrinter << "[invalid onvif soap response]:" << ex.what() << endl; + invoker(SoapObject(), SoapErr(url, SOAPAction, ex, parser, err)); + return; + } + auto fault = body["Fault"]; + if (fault) { + auto err = StrPrinter << "[onvif soap fault]:" << fault["Reason/Text"].as_xml().text().as_string() << endl;; + invoker(SoapObject(), SoapErr(url, SOAPAction, ex, parser, err)); + return; + } + //成功 + invoker(body, SoapErr(url, SOAPAction, ex, parser)); + }, timeout_sec); +} + +void SoapUtil::sendGetDeviceInformation(const std::string &device_service, const std::string &user_name, + const std::string &pwd, SoapRequestCB cb) { + static constexpr char action_url[] = R"("http://www.onvif.org/ver10/device/wsdl/GetDeviceInformation")"; + static constexpr char str_fmt[] = R"()"; + auto body = SoapUtil::createSoapRequest(str_fmt, user_name, pwd); + SoapUtil::sendSoapRequest(device_service, action_url, body, std::move(cb)); +} + +void SoapUtil::sendGetProfiles(bool is_media2, const string &media_url, const string &user_name, const string &pwd, + const onGetProfilesResponse &cb) { + auto invoker = [is_media2, cb](const SoapObject &res, const SoapErr &err) { + if (err) { + cb(err, vector()); + return; + } + multimap sorted; + for (auto &xml_node : res["GetProfilesResponse"].as_xml()) { + SoapObject obj(xml_node, res); + auto profile_name = obj["Name"].as_xml().text().as_string(); + auto token = xml_node.attribute("token"); + if (token) { + profile_name = token.value(); + } + auto codec = obj[string(is_media2 ? "Configurations/VideoEncoder/Encoding" + : "VideoEncoderConfiguration/Encoding")].as_xml().text().as_string(); + auto width = obj[string(is_media2 ? "Configurations/VideoEncoder/Resolution/Width" + : "VideoEncoderConfiguration/Resolution/Width")].as_xml().text().as_int(); + auto height = obj[string(is_media2 ? "Configurations/VideoEncoder/Resolution/Height" + : "VideoEncoderConfiguration/Resolution/Height")].as_xml().text().as_int(); + sorted.emplace(width * height, std::make_tuple(profile_name, codec, width, height)); + } + vector result; + for (auto &pr : sorted) { + result.insert(result.begin(), pr.second); + } + cb(err, result); + }; + static constexpr char action_url[] = R"("http://www.onvif.org/ver10/media/wsdl/GetProfiles")"; + static constexpr char str_fmt[] = R"()"; + static constexpr char action_url2[] = R"("http://www.onvif.org/ver20/media/wsdl/GetProfiles")"; + static constexpr char str_fmt2[] = R"(All)"; + + auto body = SoapUtil::createSoapRequest(is_media2 ? str_fmt2 : str_fmt, user_name, pwd); + SoapUtil::sendSoapRequest(media_url, is_media2 ? action_url2 : action_url, body, invoker); +} + +void SoapUtil::sendGetServices(const std::string &device_service, const std::initializer_list &ns_filter, + const std::string &user_name, const std::string &pwd, const onGetServicesResponse &cb) { + static constexpr char action_url[] = R"("http://www.onvif.org/ver10/device/wsdl/GetServices")"; + static constexpr char str_fmt[] = R"( + true + )"; + set filter = ns_filter; + auto body = SoapUtil::createSoapRequest(str_fmt, user_name, pwd); + SoapUtil::sendSoapRequest(device_service, action_url, body,[filter, cb](const SoapObject &node, const SoapErr &err) { + onGetServicesResponseMap mp; + if (err) { + cb(err, mp); + return; + } + auto res = node["GetServicesResponse"]; + for (auto &xml_node : res.as_xml()) { + SoapObject obj(xml_node, node); + string ns = obj["Namespace"].as_xml().text().as_string(); + string xaddr = obj["XAddr"].as_xml().text().as_string(); + if (filter.find(ns) != filter.end()) { + mp.emplace(ns, xaddr); + } + } + cb(err, mp); + }); +} + +static string getRtspUrlWithAuth(const std::string &user_name, const std::string &pwd, const string &url) { + RtspUrl parser; + parser.parse(url); + if (user_name.empty() || pwd.empty() || !parser._user.empty() || !parser._passwd.empty()) { + return url; + } + auto ret = url; + auto pos = ret.find("://"); + if (pos == string::npos) { + return ret; + } + ret.insert(pos + 3, (user_name + ":" + pwd + "@").data()); + return ret; +} + +void SoapUtil::sendGetStreamUri(bool is_media2, const string &media_url, const string &profile, + const std::string &user_name, const std::string &pwd, + const onGetStreamUriResponse &cb) { + if (!is_media2) { + static constexpr char action_url[] = R"("http://www.onvif.org/ver10/media/wsdl/GetStreamUri")"; + static constexpr char str_fmt[] = R"( + + %s + + %s + + + %s + )"; + + auto body = SoapUtil::createSoapRequest(print_to_string(str_fmt, "RTP-Unicast", "RTSP", profile.data()), + user_name, pwd); + SoapUtil::sendSoapRequest(media_url, action_url, body, [cb](const SoapObject &node, const SoapErr &err) { + if (err) { + cb(err, ""); + return; + } + auto res = node["GetStreamUriResponse/MediaUri/Uri"]; + cb(err, res.as_xml().text().as_string()); + }); + } else { + static constexpr char action_url[] = R"("http://www.onvif.org/ver20/media/wsdl/GetStreamUri")"; + static constexpr char str_fmt[] = R"( + %s + %s + )"; + auto body = SoapUtil::createSoapRequest(print_to_string(str_fmt, "RTSP", profile.data()), user_name, pwd); + SoapUtil::sendSoapRequest(media_url, action_url, body, [cb](const SoapObject &node, const SoapErr &err) { + if (err) { + cb(err, ""); + return; + } + auto res = node["GetStreamUriResponse/Uri"]; + cb(err, res.as_xml().text().as_string()); + }); + } +} + +static void asyncGetStreamUri_l(const std::string &onvif_url, const std::string &user_name, + const std::string &pwd, const SoapUtil::AsyncGetStreamUriCB &cb, + std::shared_ptr retry_count) { + + static constexpr char media_ns[] = "http://www.onvif.org/ver10/media/wsdl"; + static constexpr char media2_ns[] = "http://www.onvif.org/ver20/media/wsdl"; + + auto invoker = [=](const std::string &user_name, const std::string &pwd) { + ++*retry_count; + asyncGetStreamUri_l(onvif_url, user_name, pwd, cb, retry_count); + }; + SoapUtil::sendGetDeviceInformation(onvif_url, user_name, pwd, [=](const SoapObject &body, const SoapErr &err) { + if (err) { + cb(err, invoker, *retry_count, ""); + return; + } + + SoapUtil::sendGetServices(onvif_url, {media_ns, media2_ns}, user_name, pwd, + [=](const SoapErr &err, SoapUtil::onGetServicesResponseMap &mp) { + auto media1_url = mp[media_ns]; + auto media2_url = mp[media2_ns]; + auto media_url = media2_url.empty() ? media1_url : media2_url; + bool is_media2 = media2_url.empty() ? false : true; + if (err) { + cb(err, invoker, *retry_count, ""); + return; + } + if (media_url.empty()) { + static constexpr char action_url[] = R"("http://www.onvif.org/ver10/device/wsdl/GetServices")"; + SoapErr err(onvif_url, action_url, SockException(), mediakit::Parser(), + "get media service failed"); + cb(err, invoker, *retry_count, ""); + return; + } + SoapUtil::sendGetProfiles(is_media2, media_url, user_name, pwd, + [=] (const SoapErr &err, const vector &profile) { + if (err) { + cb(err, invoker, *retry_count, ""); + return; + } + if (profile.empty()) { + static constexpr char action_url[] = R"("http://www.onvif.org/ver10/media/wsdl/GetProfiles")"; + static constexpr char action_url2[] = R"("http://www.onvif.org/ver20/media/wsdl/GetProfiles")"; + SoapErr err(onvif_url, is_media2 ? action_url2 : action_url, SockException(), mediakit::Parser(), + "get media profile failed"); + cb(err, invoker, *retry_count, ""); + return; + } + auto profile_name = get<0>(profile[0]); + SoapUtil::sendGetStreamUri(is_media2, media_url, profile_name, user_name, pwd, + [=](const SoapErr &err, const string &uri) { + if (err) { + cb(err, invoker, *retry_count, ""); + return; + } + cb(err, invoker, *retry_count, getRtspUrlWithAuth(user_name, pwd, uri)); + }); + }); + }); + }); +} + +void SoapUtil::asyncGetStreamUri(const std::string &onvif_url, const SoapUtil::AsyncGetStreamUriCB &cb) { + asyncGetStreamUri_l(onvif_url, "", "", cb, std::make_shared(0)); +} \ No newline at end of file diff --git a/src/Onvif/SoapUtil.h b/src/Onvif/SoapUtil.h new file mode 100644 index 00000000..3d4f8964 --- /dev/null +++ b/src/Onvif/SoapUtil.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_SOAPUTIL_H +#define ZLMEDIAKIT_SOAPUTIL_H + +#include +#include +#include +#include +#include "Common/Parser.h" +#include "Network/Socket.h" +#include "pugixml.hpp" + +struct xml_string_writer : pugi::xml_writer { + std::string result; + virtual void write(const void *data, size_t size) { + result.append(static_cast(data), size); + } +}; + +template +std::string print_to_string(const char (&str_fmt)[sz], ARGS &&...args) { + std::string ret; + //鉴权%s长度再减去\0长度 + ret.resize(2 * sizeof(str_fmt)); + //string的真实内存大小必定比size大一个字节(用于存放\0) + auto size = snprintf((char *) ret.data(), ret.size() + 1, str_fmt, std::forward(args)...); + ret.resize(size); + return ret; +} + +class SoapObject; + +class SoapErr { +public: + SoapErr(std::string url, + std::string action, + toolkit::SockException ex, + const mediakit::Parser &parser, + std::string err = ""); + + operator std::string() const; + operator bool() const; + bool empty() const; + int httpCode() const; + +private: + std::string _url; + std::string _action; + toolkit::SockException _net_err; + int _http_code = 200; + std::string _http_msg; + std::string _other_err; +}; + +std::ostream& operator<<(std::ostream& sout, const SoapErr &err); + +class SoapUtil { +public: + static std::string createDiscoveryString(const std::string &uuid = ""); + static std::string createUuidString(); + static std::string createSoapRequest(const std::string &body, const std::string &user_name = "", const std::string &passwd = ""); + + using SoapRequestCB = std::function; + static void sendSoapRequest(const std::string &url, const std::string &action, const std::string &body, + const SoapRequestCB &func = nullptr, float timeout_sec = 10); + + + using onGetProfilesResponseTuple = std::tuple; + using onGetProfilesResponse = std::function &profile)>; + + /** + * 获取profile + * @param is_media2 是否为media2访问方式 + * @param media_url media服务访问地址 + * @param user_name 用户名 + * @param pwd 密码 + * @param cb 回调, 高分辨率的profile在前 + */ + static void sendGetProfiles(bool is_media2, const std::string &media_url, const std::string &user_name, + const std::string &pwd, const onGetProfilesResponse &cb); + + /** + * 获取设备信息 + * @param device_service device_service服务访问地址 + * @param user_name 用户名 + * @param pwd 密码 + * @param cb 回调 + */ + static void sendGetDeviceInformation(const std::string &device_service, const std::string &user_name, + const std::string &pwd, SoapRequestCB cb); + + + using onGetServicesResponseMap = std::map; + using onGetServicesResponse = std::function; + + /** + * 获取服务url地址 + * @param device_service device_service服务访问地址 + * @param ns_filter 刷选的服务的命名空间 + * @param user_name 用户名 + * @param pwd 密码 + * @param cb 回调 + */ + static void sendGetServices(const std::string &device_service, const std::initializer_list &ns_filter, + const std::string &user_name, const std::string &pwd, const onGetServicesResponse &cb); + + + using onGetStreamUriResponse = std::function; + + /** + * 获取rtsp播放url + * @param is_media2 是否为media2方式 + * @param media_url media或media2服务访问地址 + * @param profile sendGetProfiles接口获取的分辨率方案 + * @param user_name 用户名 + * @param pwd 密码 + * @param cb 回调 + */ + static void sendGetStreamUri(bool is_media2, const std::string &media_url, const std::string &profile, + const std::string &user_name, const std::string &pwd, + const onGetStreamUriResponse &cb); + + using GetStreamUriRetryInvoker = std::function; + using AsyncGetStreamUriCB = std::function; + + /** + * 异步获取播放url + * @param onvif_url 设备搜索时返回的url + * @param cb 回调 + */ + static void asyncGetStreamUri(const std::string &onvif_url, const AsyncGetStreamUriCB &cb); + +private: + SoapUtil() = delete; + ~SoapUtil() = delete; +}; + +class SoapObject { +public: + using Ptr = std::shared_ptr; + + SoapObject(const pugi::xml_node &node, const SoapObject &ref); + SoapObject(); + operator bool () const; + void load(const char *data, size_t len); + SoapObject operator[](const std::string &path) const; + + template + SoapObject operator[](const char (&path)[sz]) const{ + return (*this)[std::string(path, sz - 1)]; + } + + SoapObject operator[](size_t index) const; + std::string as_string() const; + pugi::xml_node as_xml() const; + +private: + SoapObject(std::shared_ptr node); + +private: + std::shared_ptr _root; +}; + + + +#endif //ZLMEDIAKIT_SOAPUTIL_H diff --git a/src/Onvif/pugiconfig.hpp b/src/Onvif/pugiconfig.hpp new file mode 100644 index 00000000..7dd907ed --- /dev/null +++ b/src/Onvif/pugiconfig.hpp @@ -0,0 +1,77 @@ +/** + * pugixml parser - version 1.11 + * -------------------------------------------------------- + * Copyright (C) 2006-2020, by Arseny Kapoulkine (arseny.kapoulkine@gmail.com) + * Report bugs and download new versions at https://pugixml.org/ + * + * This library is distributed under the MIT License. See notice at the end + * of this file. + * + * This work is based on the pugxml parser, which is: + * Copyright (C) 2003, by Kristen Wegner (kristen@tima.net) + */ + +#ifndef HEADER_PUGICONFIG_HPP +#define HEADER_PUGICONFIG_HPP + +// Uncomment this to enable wchar_t mode +// #define PUGIXML_WCHAR_MODE + +// Uncomment this to enable compact mode +// #define PUGIXML_COMPACT + +// Uncomment this to disable XPath +// #define PUGIXML_NO_XPATH + +// Uncomment this to disable STL +// #define PUGIXML_NO_STL + +// Uncomment this to disable exceptions +// #define PUGIXML_NO_EXCEPTIONS + +// Set this to control attributes for public classes/functions, i.e.: +// #define PUGIXML_API __declspec(dllexport) // to export all public symbols from DLL +// #define PUGIXML_CLASS __declspec(dllimport) // to import all classes from DLL +// #define PUGIXML_FUNCTION __fastcall // to set calling conventions to all public functions to fastcall +// In absence of PUGIXML_CLASS/PUGIXML_FUNCTION definitions PUGIXML_API is used instead + +// Tune these constants to adjust memory-related behavior +// #define PUGIXML_MEMORY_PAGE_SIZE 32768 +// #define PUGIXML_MEMORY_OUTPUT_STACK 10240 +// #define PUGIXML_MEMORY_XPATH_PAGE_SIZE 4096 + +// Tune this constant to adjust max nesting for XPath queries +// #define PUGIXML_XPATH_DEPTH_LIMIT 1024 + +// Uncomment this to switch to header-only version +// #define PUGIXML_HEADER_ONLY + +// Uncomment this to enable long long support +// #define PUGIXML_HAS_LONG_LONG + +#endif + +/** + * Copyright (c) 2006-2020 Arseny Kapoulkine + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ diff --git a/src/Onvif/pugixml.cpp b/src/Onvif/pugixml.cpp new file mode 100644 index 00000000..5609ab4e --- /dev/null +++ b/src/Onvif/pugixml.cpp @@ -0,0 +1,13027 @@ +/** + * pugixml parser - version 1.11 + * -------------------------------------------------------- + * Copyright (C) 2006-2020, by Arseny Kapoulkine (arseny.kapoulkine@gmail.com) + * Report bugs and download new versions at https://pugixml.org/ + * + * This library is distributed under the MIT License. See notice at the end + * of this file. + * + * This work is based on the pugxml parser, which is: + * Copyright (C) 2003, by Kristen Wegner (kristen@tima.net) + */ + +#ifndef SOURCE_PUGIXML_CPP +#define SOURCE_PUGIXML_CPP + +#include "pugixml.hpp" + +#include +#include +#include +#include +#include + +#ifdef PUGIXML_WCHAR_MODE +# include +#endif + +#ifndef PUGIXML_NO_XPATH +# include +# include +#endif + +#ifndef PUGIXML_NO_STL +# include +# include +# include +#endif + +// For placement new +#include + +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable: 4127) // conditional expression is constant +# pragma warning(disable: 4324) // structure was padded due to __declspec(align()) +# pragma warning(disable: 4702) // unreachable code +# pragma warning(disable: 4996) // this function or variable may be unsafe +#endif + +#if defined(_MSC_VER) && defined(__c2__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdeprecated" // this function or variable may be unsafe +#endif + +#ifdef __INTEL_COMPILER +# pragma warning(disable: 177) // function was declared but never referenced +# pragma warning(disable: 279) // controlling expression is constant +# pragma warning(disable: 1478 1786) // function was declared "deprecated" +# pragma warning(disable: 1684) // conversion from pointer to same-sized integral type +#endif + +#if defined(__BORLANDC__) && defined(PUGIXML_HEADER_ONLY) +# pragma warn -8080 // symbol is declared but never used; disabling this inside push/pop bracket does not make the warning go away +#endif + +#ifdef __BORLANDC__ +# pragma option push +# pragma warn -8008 // condition is always false +# pragma warn -8066 // unreachable code +#endif + +#ifdef __SNC__ +// Using diag_push/diag_pop does not disable the warnings inside templates due to a compiler bug +# pragma diag_suppress=178 // function was declared but never referenced +# pragma diag_suppress=237 // controlling expression is constant +#endif + +#ifdef __TI_COMPILER_VERSION__ +# pragma diag_suppress 179 // function was declared but never referenced +#endif + +// Inlining controls +#if defined(_MSC_VER) && _MSC_VER >= 1300 +# define PUGI__NO_INLINE __declspec(noinline) +#elif defined(__GNUC__) +# define PUGI__NO_INLINE __attribute__((noinline)) +#else +# define PUGI__NO_INLINE +#endif + +// Branch weight controls +#if defined(__GNUC__) && !defined(__c2__) +# define PUGI__UNLIKELY(cond) __builtin_expect(cond, 0) +#else +# define PUGI__UNLIKELY(cond) (cond) +#endif + +// Simple static assertion +#define PUGI__STATIC_ASSERT(cond) { static const char condition_failed[(cond) ? 1 : -1] = {0}; (void)condition_failed[0]; } + +// Digital Mars C++ bug workaround for passing char loaded from memory via stack +#ifdef __DMC__ +# define PUGI__DMC_VOLATILE volatile +#else +# define PUGI__DMC_VOLATILE +#endif + +// Integer sanitizer workaround; we only apply this for clang since gcc8 has no_sanitize but not unsigned-integer-overflow and produces "attribute directive ignored" warnings +#if defined(__clang__) && defined(__has_attribute) +# if __has_attribute(no_sanitize) +# define PUGI__UNSIGNED_OVERFLOW __attribute__((no_sanitize("unsigned-integer-overflow"))) +# else +# define PUGI__UNSIGNED_OVERFLOW +# endif +#else +# define PUGI__UNSIGNED_OVERFLOW +#endif + +// Borland C++ bug workaround for not defining ::memcpy depending on header include order (can't always use std::memcpy because some compilers don't have it at all) +#if defined(__BORLANDC__) && !defined(__MEM_H_USING_LIST) +using std::memcpy; +using std::memmove; +using std::memset; +#endif + +// Some MinGW/GCC versions have headers that erroneously omit LLONG_MIN/LLONG_MAX/ULLONG_MAX definitions from limits.h in some configurations +#if defined(PUGIXML_HAS_LONG_LONG) && defined(__GNUC__) && !defined(LLONG_MAX) && !defined(LLONG_MIN) && !defined(ULLONG_MAX) +# define LLONG_MIN (-LLONG_MAX - 1LL) +# define LLONG_MAX __LONG_LONG_MAX__ +# define ULLONG_MAX (LLONG_MAX * 2ULL + 1ULL) +#endif + +// In some environments MSVC is a compiler but the CRT lacks certain MSVC-specific features +#if defined(_MSC_VER) && !defined(__S3E__) +# define PUGI__MSVC_CRT_VERSION _MSC_VER +#endif + +// Not all platforms have snprintf; we define a wrapper that uses snprintf if possible. This only works with buffers with a known size. +#if __cplusplus >= 201103 +# define PUGI__SNPRINTF(buf, ...) snprintf(buf, sizeof(buf), __VA_ARGS__) +#elif defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 +# define PUGI__SNPRINTF(buf, ...) _snprintf_s(buf, _countof(buf), _TRUNCATE, __VA_ARGS__) +#else +# define PUGI__SNPRINTF sprintf +#endif + +// We put implementation details into an anonymous namespace in source mode, but have to keep it in non-anonymous namespace in header-only mode to prevent binary bloat. +#ifdef PUGIXML_HEADER_ONLY +# define PUGI__NS_BEGIN namespace pugi { namespace impl { +# define PUGI__NS_END } } +# define PUGI__FN inline +# define PUGI__FN_NO_INLINE inline +#else +# if defined(_MSC_VER) && _MSC_VER < 1300 // MSVC6 seems to have an amusing bug with anonymous namespaces inside namespaces +# define PUGI__NS_BEGIN namespace pugi { namespace impl { +# define PUGI__NS_END } } +# else +# define PUGI__NS_BEGIN namespace pugi { namespace impl { namespace { +# define PUGI__NS_END } } } +# endif +# define PUGI__FN +# define PUGI__FN_NO_INLINE PUGI__NO_INLINE +#endif + +// uintptr_t +#if (defined(_MSC_VER) && _MSC_VER < 1600) || (defined(__BORLANDC__) && __BORLANDC__ < 0x561) +namespace pugi +{ +# ifndef _UINTPTR_T_DEFINED + typedef size_t uintptr_t; +# endif + + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; +} +#else +# include +#endif + +// Memory allocation +PUGI__NS_BEGIN + PUGI__FN void* default_allocate(size_t size) + { + return malloc(size); + } + + PUGI__FN void default_deallocate(void* ptr) + { + free(ptr); + } + + template + struct xml_memory_management_function_storage + { + static allocation_function allocate; + static deallocation_function deallocate; + }; + + // Global allocation functions are stored in class statics so that in header mode linker deduplicates them + // Without a template<> we'll get multiple definitions of the same static + template allocation_function xml_memory_management_function_storage::allocate = default_allocate; + template deallocation_function xml_memory_management_function_storage::deallocate = default_deallocate; + + typedef xml_memory_management_function_storage xml_memory; +PUGI__NS_END + +// String utilities +PUGI__NS_BEGIN + // Get string length + PUGI__FN size_t strlength(const char_t* s) + { + assert(s); + + #ifdef PUGIXML_WCHAR_MODE + return wcslen(s); + #else + return strlen(s); + #endif + } + + // Compare two strings + PUGI__FN bool strequal(const char_t* src, const char_t* dst) + { + assert(src && dst); + + #ifdef PUGIXML_WCHAR_MODE + return wcscmp(src, dst) == 0; + #else + return strcmp(src, dst) == 0; + #endif + } + + // Compare lhs with [rhs_begin, rhs_end) + PUGI__FN bool strequalrange(const char_t* lhs, const char_t* rhs, size_t count) + { + for (size_t i = 0; i < count; ++i) + if (lhs[i] != rhs[i]) + return false; + + return lhs[count] == 0; + } + + // Get length of wide string, even if CRT lacks wide character support + PUGI__FN size_t strlength_wide(const wchar_t* s) + { + assert(s); + + #ifdef PUGIXML_WCHAR_MODE + return wcslen(s); + #else + const wchar_t* end = s; + while (*end) end++; + return static_cast(end - s); + #endif + } +PUGI__NS_END + +// auto_ptr-like object for exception recovery +PUGI__NS_BEGIN + template struct auto_deleter + { + typedef void (*D)(T*); + + T* data; + D deleter; + + auto_deleter(T* data_, D deleter_): data(data_), deleter(deleter_) + { + } + + ~auto_deleter() + { + if (data) deleter(data); + } + + T* release() + { + T* result = data; + data = 0; + return result; + } + }; +PUGI__NS_END + +#ifdef PUGIXML_COMPACT +PUGI__NS_BEGIN + class compact_hash_table + { + public: + compact_hash_table(): _items(0), _capacity(0), _count(0) + { + } + + void clear() + { + if (_items) + { + xml_memory::deallocate(_items); + _items = 0; + _capacity = 0; + _count = 0; + } + } + + void* find(const void* key) + { + if (_capacity == 0) return 0; + + item_t* item = get_item(key); + assert(item); + assert(item->key == key || (item->key == 0 && item->value == 0)); + + return item->value; + } + + void insert(const void* key, void* value) + { + assert(_capacity != 0 && _count < _capacity - _capacity / 4); + + item_t* item = get_item(key); + assert(item); + + if (item->key == 0) + { + _count++; + item->key = key; + } + + item->value = value; + } + + bool reserve(size_t extra = 16) + { + if (_count + extra >= _capacity - _capacity / 4) + return rehash(_count + extra); + + return true; + } + + private: + struct item_t + { + const void* key; + void* value; + }; + + item_t* _items; + size_t _capacity; + + size_t _count; + + bool rehash(size_t count); + + item_t* get_item(const void* key) + { + assert(key); + assert(_capacity > 0); + + size_t hashmod = _capacity - 1; + size_t bucket = hash(key) & hashmod; + + for (size_t probe = 0; probe <= hashmod; ++probe) + { + item_t& probe_item = _items[bucket]; + + if (probe_item.key == key || probe_item.key == 0) + return &probe_item; + + // hash collision, quadratic probing + bucket = (bucket + probe + 1) & hashmod; + } + + assert(false && "Hash table is full"); // unreachable + return 0; + } + + static PUGI__UNSIGNED_OVERFLOW unsigned int hash(const void* key) + { + unsigned int h = static_cast(reinterpret_cast(key) & 0xffffffff); + + // MurmurHash3 32-bit finalizer + h ^= h >> 16; + h *= 0x85ebca6bu; + h ^= h >> 13; + h *= 0xc2b2ae35u; + h ^= h >> 16; + + return h; + } + }; + + PUGI__FN_NO_INLINE bool compact_hash_table::rehash(size_t count) + { + size_t capacity = 32; + while (count >= capacity - capacity / 4) + capacity *= 2; + + compact_hash_table rt; + rt._capacity = capacity; + rt._items = static_cast(xml_memory::allocate(sizeof(item_t) * capacity)); + + if (!rt._items) + return false; + + memset(rt._items, 0, sizeof(item_t) * capacity); + + for (size_t i = 0; i < _capacity; ++i) + if (_items[i].key) + rt.insert(_items[i].key, _items[i].value); + + if (_items) + xml_memory::deallocate(_items); + + _capacity = capacity; + _items = rt._items; + + assert(_count == rt._count); + + return true; + } + +PUGI__NS_END +#endif + +PUGI__NS_BEGIN +#ifdef PUGIXML_COMPACT + static const uintptr_t xml_memory_block_alignment = 4; +#else + static const uintptr_t xml_memory_block_alignment = sizeof(void*); +#endif + + // extra metadata bits + static const uintptr_t xml_memory_page_contents_shared_mask = 64; + static const uintptr_t xml_memory_page_name_allocated_mask = 32; + static const uintptr_t xml_memory_page_value_allocated_mask = 16; + static const uintptr_t xml_memory_page_type_mask = 15; + + // combined masks for string uniqueness + static const uintptr_t xml_memory_page_name_allocated_or_shared_mask = xml_memory_page_name_allocated_mask | xml_memory_page_contents_shared_mask; + static const uintptr_t xml_memory_page_value_allocated_or_shared_mask = xml_memory_page_value_allocated_mask | xml_memory_page_contents_shared_mask; + +#ifdef PUGIXML_COMPACT + #define PUGI__GETHEADER_IMPL(object, page, flags) // unused + #define PUGI__GETPAGE_IMPL(header) (header).get_page() +#else + #define PUGI__GETHEADER_IMPL(object, page, flags) (((reinterpret_cast(object) - reinterpret_cast(page)) << 8) | (flags)) + // this macro casts pointers through void* to avoid 'cast increases required alignment of target type' warnings + #define PUGI__GETPAGE_IMPL(header) static_cast(const_cast(static_cast(reinterpret_cast(&header) - (header >> 8)))) +#endif + + #define PUGI__GETPAGE(n) PUGI__GETPAGE_IMPL((n)->header) + #define PUGI__NODETYPE(n) static_cast((n)->header & impl::xml_memory_page_type_mask) + + struct xml_allocator; + + struct xml_memory_page + { + static xml_memory_page* construct(void* memory) + { + xml_memory_page* result = static_cast(memory); + + result->allocator = 0; + result->prev = 0; + result->next = 0; + result->busy_size = 0; + result->freed_size = 0; + + #ifdef PUGIXML_COMPACT + result->compact_string_base = 0; + result->compact_shared_parent = 0; + result->compact_page_marker = 0; + #endif + + return result; + } + + xml_allocator* allocator; + + xml_memory_page* prev; + xml_memory_page* next; + + size_t busy_size; + size_t freed_size; + + #ifdef PUGIXML_COMPACT + char_t* compact_string_base; + void* compact_shared_parent; + uint32_t* compact_page_marker; + #endif + }; + + static const size_t xml_memory_page_size = + #ifdef PUGIXML_MEMORY_PAGE_SIZE + (PUGIXML_MEMORY_PAGE_SIZE) + #else + 32768 + #endif + - sizeof(xml_memory_page); + + struct xml_memory_string_header + { + uint16_t page_offset; // offset from page->data + uint16_t full_size; // 0 if string occupies whole page + }; + + struct xml_allocator + { + xml_allocator(xml_memory_page* root): _root(root), _busy_size(root->busy_size) + { + #ifdef PUGIXML_COMPACT + _hash = 0; + #endif + } + + xml_memory_page* allocate_page(size_t data_size) + { + size_t size = sizeof(xml_memory_page) + data_size; + + // allocate block with some alignment, leaving memory for worst-case padding + void* memory = xml_memory::allocate(size); + if (!memory) return 0; + + // prepare page structure + xml_memory_page* page = xml_memory_page::construct(memory); + assert(page); + + assert(this == _root->allocator); + page->allocator = this; + + return page; + } + + static void deallocate_page(xml_memory_page* page) + { + xml_memory::deallocate(page); + } + + void* allocate_memory_oob(size_t size, xml_memory_page*& out_page); + + void* allocate_memory(size_t size, xml_memory_page*& out_page) + { + if (PUGI__UNLIKELY(_busy_size + size > xml_memory_page_size)) + return allocate_memory_oob(size, out_page); + + void* buf = reinterpret_cast(_root) + sizeof(xml_memory_page) + _busy_size; + + _busy_size += size; + + out_page = _root; + + return buf; + } + + #ifdef PUGIXML_COMPACT + void* allocate_object(size_t size, xml_memory_page*& out_page) + { + void* result = allocate_memory(size + sizeof(uint32_t), out_page); + if (!result) return 0; + + // adjust for marker + ptrdiff_t offset = static_cast(result) - reinterpret_cast(out_page->compact_page_marker); + + if (PUGI__UNLIKELY(static_cast(offset) >= 256 * xml_memory_block_alignment)) + { + // insert new marker + uint32_t* marker = static_cast(result); + + *marker = static_cast(reinterpret_cast(marker) - reinterpret_cast(out_page)); + out_page->compact_page_marker = marker; + + // since we don't reuse the page space until we reallocate it, we can just pretend that we freed the marker block + // this will make sure deallocate_memory correctly tracks the size + out_page->freed_size += sizeof(uint32_t); + + return marker + 1; + } + else + { + // roll back uint32_t part + _busy_size -= sizeof(uint32_t); + + return result; + } + } + #else + void* allocate_object(size_t size, xml_memory_page*& out_page) + { + return allocate_memory(size, out_page); + } + #endif + + void deallocate_memory(void* ptr, size_t size, xml_memory_page* page) + { + if (page == _root) page->busy_size = _busy_size; + + assert(ptr >= reinterpret_cast(page) + sizeof(xml_memory_page) && ptr < reinterpret_cast(page) + sizeof(xml_memory_page) + page->busy_size); + (void)!ptr; + + page->freed_size += size; + assert(page->freed_size <= page->busy_size); + + if (page->freed_size == page->busy_size) + { + if (page->next == 0) + { + assert(_root == page); + + // top page freed, just reset sizes + page->busy_size = 0; + page->freed_size = 0; + + #ifdef PUGIXML_COMPACT + // reset compact state to maximize efficiency + page->compact_string_base = 0; + page->compact_shared_parent = 0; + page->compact_page_marker = 0; + #endif + + _busy_size = 0; + } + else + { + assert(_root != page); + assert(page->prev); + + // remove from the list + page->prev->next = page->next; + page->next->prev = page->prev; + + // deallocate + deallocate_page(page); + } + } + } + + char_t* allocate_string(size_t length) + { + static const size_t max_encoded_offset = (1 << 16) * xml_memory_block_alignment; + + PUGI__STATIC_ASSERT(xml_memory_page_size <= max_encoded_offset); + + // allocate memory for string and header block + size_t size = sizeof(xml_memory_string_header) + length * sizeof(char_t); + + // round size up to block alignment boundary + size_t full_size = (size + (xml_memory_block_alignment - 1)) & ~(xml_memory_block_alignment - 1); + + xml_memory_page* page; + xml_memory_string_header* header = static_cast(allocate_memory(full_size, page)); + + if (!header) return 0; + + // setup header + ptrdiff_t page_offset = reinterpret_cast(header) - reinterpret_cast(page) - sizeof(xml_memory_page); + + assert(page_offset % xml_memory_block_alignment == 0); + assert(page_offset >= 0 && static_cast(page_offset) < max_encoded_offset); + header->page_offset = static_cast(static_cast(page_offset) / xml_memory_block_alignment); + + // full_size == 0 for large strings that occupy the whole page + assert(full_size % xml_memory_block_alignment == 0); + assert(full_size < max_encoded_offset || (page->busy_size == full_size && page_offset == 0)); + header->full_size = static_cast(full_size < max_encoded_offset ? full_size / xml_memory_block_alignment : 0); + + // round-trip through void* to avoid 'cast increases required alignment of target type' warning + // header is guaranteed a pointer-sized alignment, which should be enough for char_t + return static_cast(static_cast(header + 1)); + } + + void deallocate_string(char_t* string) + { + // this function casts pointers through void* to avoid 'cast increases required alignment of target type' warnings + // we're guaranteed the proper (pointer-sized) alignment on the input string if it was allocated via allocate_string + + // get header + xml_memory_string_header* header = static_cast(static_cast(string)) - 1; + assert(header); + + // deallocate + size_t page_offset = sizeof(xml_memory_page) + header->page_offset * xml_memory_block_alignment; + xml_memory_page* page = reinterpret_cast(static_cast(reinterpret_cast(header) - page_offset)); + + // if full_size == 0 then this string occupies the whole page + size_t full_size = header->full_size == 0 ? page->busy_size : header->full_size * xml_memory_block_alignment; + + deallocate_memory(header, full_size, page); + } + + bool reserve() + { + #ifdef PUGIXML_COMPACT + return _hash->reserve(); + #else + return true; + #endif + } + + xml_memory_page* _root; + size_t _busy_size; + + #ifdef PUGIXML_COMPACT + compact_hash_table* _hash; + #endif + }; + + PUGI__FN_NO_INLINE void* xml_allocator::allocate_memory_oob(size_t size, xml_memory_page*& out_page) + { + const size_t large_allocation_threshold = xml_memory_page_size / 4; + + xml_memory_page* page = allocate_page(size <= large_allocation_threshold ? xml_memory_page_size : size); + out_page = page; + + if (!page) return 0; + + if (size <= large_allocation_threshold) + { + _root->busy_size = _busy_size; + + // insert page at the end of linked list + page->prev = _root; + _root->next = page; + _root = page; + + _busy_size = size; + } + else + { + // insert page before the end of linked list, so that it is deleted as soon as possible + // the last page is not deleted even if it's empty (see deallocate_memory) + assert(_root->prev); + + page->prev = _root->prev; + page->next = _root; + + _root->prev->next = page; + _root->prev = page; + + page->busy_size = size; + } + + return reinterpret_cast(page) + sizeof(xml_memory_page); + } +PUGI__NS_END + +#ifdef PUGIXML_COMPACT +PUGI__NS_BEGIN + static const uintptr_t compact_alignment_log2 = 2; + static const uintptr_t compact_alignment = 1 << compact_alignment_log2; + + class compact_header + { + public: + compact_header(xml_memory_page* page, unsigned int flags) + { + PUGI__STATIC_ASSERT(xml_memory_block_alignment == compact_alignment); + + ptrdiff_t offset = (reinterpret_cast(this) - reinterpret_cast(page->compact_page_marker)); + assert(offset % compact_alignment == 0 && static_cast(offset) < 256 * compact_alignment); + + _page = static_cast(offset >> compact_alignment_log2); + _flags = static_cast(flags); + } + + void operator&=(uintptr_t mod) + { + _flags &= static_cast(mod); + } + + void operator|=(uintptr_t mod) + { + _flags |= static_cast(mod); + } + + uintptr_t operator&(uintptr_t mod) const + { + return _flags & mod; + } + + xml_memory_page* get_page() const + { + // round-trip through void* to silence 'cast increases required alignment of target type' warnings + const char* page_marker = reinterpret_cast(this) - (_page << compact_alignment_log2); + const char* page = page_marker - *reinterpret_cast(static_cast(page_marker)); + + return const_cast(reinterpret_cast(static_cast(page))); + } + + private: + unsigned char _page; + unsigned char _flags; + }; + + PUGI__FN xml_memory_page* compact_get_page(const void* object, int header_offset) + { + const compact_header* header = reinterpret_cast(static_cast(object) - header_offset); + + return header->get_page(); + } + + template PUGI__FN_NO_INLINE T* compact_get_value(const void* object) + { + return static_cast(compact_get_page(object, header_offset)->allocator->_hash->find(object)); + } + + template PUGI__FN_NO_INLINE void compact_set_value(const void* object, T* value) + { + compact_get_page(object, header_offset)->allocator->_hash->insert(object, value); + } + + template class compact_pointer + { + public: + compact_pointer(): _data(0) + { + } + + void operator=(const compact_pointer& rhs) + { + *this = rhs + 0; + } + + void operator=(T* value) + { + if (value) + { + // value is guaranteed to be compact-aligned; 'this' is not + // our decoding is based on 'this' aligned to compact alignment downwards (see operator T*) + // so for negative offsets (e.g. -3) we need to adjust the diff by compact_alignment - 1 to + // compensate for arithmetic shift rounding for negative values + ptrdiff_t diff = reinterpret_cast(value) - reinterpret_cast(this); + ptrdiff_t offset = ((diff + int(compact_alignment - 1)) >> compact_alignment_log2) - start; + + if (static_cast(offset) <= 253) + _data = static_cast(offset + 1); + else + { + compact_set_value(this, value); + + _data = 255; + } + } + else + _data = 0; + } + + operator T*() const + { + if (_data) + { + if (_data < 255) + { + uintptr_t base = reinterpret_cast(this) & ~(compact_alignment - 1); + + return reinterpret_cast(base + (_data - 1 + start) * compact_alignment); + } + else + return compact_get_value(this); + } + else + return 0; + } + + T* operator->() const + { + return *this; + } + + private: + unsigned char _data; + }; + + template class compact_pointer_parent + { + public: + compact_pointer_parent(): _data(0) + { + } + + void operator=(const compact_pointer_parent& rhs) + { + *this = rhs + 0; + } + + void operator=(T* value) + { + if (value) + { + // value is guaranteed to be compact-aligned; 'this' is not + // our decoding is based on 'this' aligned to compact alignment downwards (see operator T*) + // so for negative offsets (e.g. -3) we need to adjust the diff by compact_alignment - 1 to + // compensate for arithmetic shift behavior for negative values + ptrdiff_t diff = reinterpret_cast(value) - reinterpret_cast(this); + ptrdiff_t offset = ((diff + int(compact_alignment - 1)) >> compact_alignment_log2) + 65533; + + if (static_cast(offset) <= 65533) + { + _data = static_cast(offset + 1); + } + else + { + xml_memory_page* page = compact_get_page(this, header_offset); + + if (PUGI__UNLIKELY(page->compact_shared_parent == 0)) + page->compact_shared_parent = value; + + if (page->compact_shared_parent == value) + { + _data = 65534; + } + else + { + compact_set_value(this, value); + + _data = 65535; + } + } + } + else + { + _data = 0; + } + } + + operator T*() const + { + if (_data) + { + if (_data < 65534) + { + uintptr_t base = reinterpret_cast(this) & ~(compact_alignment - 1); + + return reinterpret_cast(base + (_data - 1 - 65533) * compact_alignment); + } + else if (_data == 65534) + return static_cast(compact_get_page(this, header_offset)->compact_shared_parent); + else + return compact_get_value(this); + } + else + return 0; + } + + T* operator->() const + { + return *this; + } + + private: + uint16_t _data; + }; + + template class compact_string + { + public: + compact_string(): _data(0) + { + } + + void operator=(const compact_string& rhs) + { + *this = rhs + 0; + } + + void operator=(char_t* value) + { + if (value) + { + xml_memory_page* page = compact_get_page(this, header_offset); + + if (PUGI__UNLIKELY(page->compact_string_base == 0)) + page->compact_string_base = value; + + ptrdiff_t offset = value - page->compact_string_base; + + if (static_cast(offset) < (65535 << 7)) + { + // round-trip through void* to silence 'cast increases required alignment of target type' warnings + uint16_t* base = reinterpret_cast(static_cast(reinterpret_cast(this) - base_offset)); + + if (*base == 0) + { + *base = static_cast((offset >> 7) + 1); + _data = static_cast((offset & 127) + 1); + } + else + { + ptrdiff_t remainder = offset - ((*base - 1) << 7); + + if (static_cast(remainder) <= 253) + { + _data = static_cast(remainder + 1); + } + else + { + compact_set_value(this, value); + + _data = 255; + } + } + } + else + { + compact_set_value(this, value); + + _data = 255; + } + } + else + { + _data = 0; + } + } + + operator char_t*() const + { + if (_data) + { + if (_data < 255) + { + xml_memory_page* page = compact_get_page(this, header_offset); + + // round-trip through void* to silence 'cast increases required alignment of target type' warnings + const uint16_t* base = reinterpret_cast(static_cast(reinterpret_cast(this) - base_offset)); + assert(*base); + + ptrdiff_t offset = ((*base - 1) << 7) + (_data - 1); + + return page->compact_string_base + offset; + } + else + { + return compact_get_value(this); + } + } + else + return 0; + } + + private: + unsigned char _data; + }; +PUGI__NS_END +#endif + +#ifdef PUGIXML_COMPACT +namespace pugi +{ + struct xml_attribute_struct + { + xml_attribute_struct(impl::xml_memory_page* page): header(page, 0), namevalue_base(0) + { + PUGI__STATIC_ASSERT(sizeof(xml_attribute_struct) == 8); + } + + impl::compact_header header; + + uint16_t namevalue_base; + + impl::compact_string<4, 2> name; + impl::compact_string<5, 3> value; + + impl::compact_pointer prev_attribute_c; + impl::compact_pointer next_attribute; + }; + + struct xml_node_struct + { + xml_node_struct(impl::xml_memory_page* page, xml_node_type type): header(page, type), namevalue_base(0) + { + PUGI__STATIC_ASSERT(sizeof(xml_node_struct) == 12); + } + + impl::compact_header header; + + uint16_t namevalue_base; + + impl::compact_string<4, 2> name; + impl::compact_string<5, 3> value; + + impl::compact_pointer_parent parent; + + impl::compact_pointer first_child; + + impl::compact_pointer prev_sibling_c; + impl::compact_pointer next_sibling; + + impl::compact_pointer first_attribute; + }; +} +#else +namespace pugi +{ + struct xml_attribute_struct + { + xml_attribute_struct(impl::xml_memory_page* page): name(0), value(0), prev_attribute_c(0), next_attribute(0) + { + header = PUGI__GETHEADER_IMPL(this, page, 0); + } + + uintptr_t header; + + char_t* name; + char_t* value; + + xml_attribute_struct* prev_attribute_c; + xml_attribute_struct* next_attribute; + }; + + struct xml_node_struct + { + xml_node_struct(impl::xml_memory_page* page, xml_node_type type): name(0), value(0), parent(0), first_child(0), prev_sibling_c(0), next_sibling(0), first_attribute(0) + { + header = PUGI__GETHEADER_IMPL(this, page, type); + } + + uintptr_t header; + + char_t* name; + char_t* value; + + xml_node_struct* parent; + + xml_node_struct* first_child; + + xml_node_struct* prev_sibling_c; + xml_node_struct* next_sibling; + + xml_attribute_struct* first_attribute; + }; +} +#endif + +PUGI__NS_BEGIN + struct xml_extra_buffer + { + char_t* buffer; + xml_extra_buffer* next; + }; + + struct xml_document_struct: public xml_node_struct, public xml_allocator + { + xml_document_struct(xml_memory_page* page): xml_node_struct(page, node_document), xml_allocator(page), buffer(0), extra_buffers(0) + { + } + + const char_t* buffer; + + xml_extra_buffer* extra_buffers; + + #ifdef PUGIXML_COMPACT + compact_hash_table hash; + #endif + }; + + template inline xml_allocator& get_allocator(const Object* object) + { + assert(object); + + return *PUGI__GETPAGE(object)->allocator; + } + + template inline xml_document_struct& get_document(const Object* object) + { + assert(object); + + return *static_cast(PUGI__GETPAGE(object)->allocator); + } +PUGI__NS_END + +// Low-level DOM operations +PUGI__NS_BEGIN + inline xml_attribute_struct* allocate_attribute(xml_allocator& alloc) + { + xml_memory_page* page; + void* memory = alloc.allocate_object(sizeof(xml_attribute_struct), page); + if (!memory) return 0; + + return new (memory) xml_attribute_struct(page); + } + + inline xml_node_struct* allocate_node(xml_allocator& alloc, xml_node_type type) + { + xml_memory_page* page; + void* memory = alloc.allocate_object(sizeof(xml_node_struct), page); + if (!memory) return 0; + + return new (memory) xml_node_struct(page, type); + } + + inline void destroy_attribute(xml_attribute_struct* a, xml_allocator& alloc) + { + if (a->header & impl::xml_memory_page_name_allocated_mask) + alloc.deallocate_string(a->name); + + if (a->header & impl::xml_memory_page_value_allocated_mask) + alloc.deallocate_string(a->value); + + alloc.deallocate_memory(a, sizeof(xml_attribute_struct), PUGI__GETPAGE(a)); + } + + inline void destroy_node(xml_node_struct* n, xml_allocator& alloc) + { + if (n->header & impl::xml_memory_page_name_allocated_mask) + alloc.deallocate_string(n->name); + + if (n->header & impl::xml_memory_page_value_allocated_mask) + alloc.deallocate_string(n->value); + + for (xml_attribute_struct* attr = n->first_attribute; attr; ) + { + xml_attribute_struct* next = attr->next_attribute; + + destroy_attribute(attr, alloc); + + attr = next; + } + + for (xml_node_struct* child = n->first_child; child; ) + { + xml_node_struct* next = child->next_sibling; + + destroy_node(child, alloc); + + child = next; + } + + alloc.deallocate_memory(n, sizeof(xml_node_struct), PUGI__GETPAGE(n)); + } + + inline void append_node(xml_node_struct* child, xml_node_struct* node) + { + child->parent = node; + + xml_node_struct* head = node->first_child; + + if (head) + { + xml_node_struct* tail = head->prev_sibling_c; + + tail->next_sibling = child; + child->prev_sibling_c = tail; + head->prev_sibling_c = child; + } + else + { + node->first_child = child; + child->prev_sibling_c = child; + } + } + + inline void prepend_node(xml_node_struct* child, xml_node_struct* node) + { + child->parent = node; + + xml_node_struct* head = node->first_child; + + if (head) + { + child->prev_sibling_c = head->prev_sibling_c; + head->prev_sibling_c = child; + } + else + child->prev_sibling_c = child; + + child->next_sibling = head; + node->first_child = child; + } + + inline void insert_node_after(xml_node_struct* child, xml_node_struct* node) + { + xml_node_struct* parent = node->parent; + + child->parent = parent; + + if (node->next_sibling) + node->next_sibling->prev_sibling_c = child; + else + parent->first_child->prev_sibling_c = child; + + child->next_sibling = node->next_sibling; + child->prev_sibling_c = node; + + node->next_sibling = child; + } + + inline void insert_node_before(xml_node_struct* child, xml_node_struct* node) + { + xml_node_struct* parent = node->parent; + + child->parent = parent; + + if (node->prev_sibling_c->next_sibling) + node->prev_sibling_c->next_sibling = child; + else + parent->first_child = child; + + child->prev_sibling_c = node->prev_sibling_c; + child->next_sibling = node; + + node->prev_sibling_c = child; + } + + inline void remove_node(xml_node_struct* node) + { + xml_node_struct* parent = node->parent; + + if (node->next_sibling) + node->next_sibling->prev_sibling_c = node->prev_sibling_c; + else + parent->first_child->prev_sibling_c = node->prev_sibling_c; + + if (node->prev_sibling_c->next_sibling) + node->prev_sibling_c->next_sibling = node->next_sibling; + else + parent->first_child = node->next_sibling; + + node->parent = 0; + node->prev_sibling_c = 0; + node->next_sibling = 0; + } + + inline void append_attribute(xml_attribute_struct* attr, xml_node_struct* node) + { + xml_attribute_struct* head = node->first_attribute; + + if (head) + { + xml_attribute_struct* tail = head->prev_attribute_c; + + tail->next_attribute = attr; + attr->prev_attribute_c = tail; + head->prev_attribute_c = attr; + } + else + { + node->first_attribute = attr; + attr->prev_attribute_c = attr; + } + } + + inline void prepend_attribute(xml_attribute_struct* attr, xml_node_struct* node) + { + xml_attribute_struct* head = node->first_attribute; + + if (head) + { + attr->prev_attribute_c = head->prev_attribute_c; + head->prev_attribute_c = attr; + } + else + attr->prev_attribute_c = attr; + + attr->next_attribute = head; + node->first_attribute = attr; + } + + inline void insert_attribute_after(xml_attribute_struct* attr, xml_attribute_struct* place, xml_node_struct* node) + { + if (place->next_attribute) + place->next_attribute->prev_attribute_c = attr; + else + node->first_attribute->prev_attribute_c = attr; + + attr->next_attribute = place->next_attribute; + attr->prev_attribute_c = place; + place->next_attribute = attr; + } + + inline void insert_attribute_before(xml_attribute_struct* attr, xml_attribute_struct* place, xml_node_struct* node) + { + if (place->prev_attribute_c->next_attribute) + place->prev_attribute_c->next_attribute = attr; + else + node->first_attribute = attr; + + attr->prev_attribute_c = place->prev_attribute_c; + attr->next_attribute = place; + place->prev_attribute_c = attr; + } + + inline void remove_attribute(xml_attribute_struct* attr, xml_node_struct* node) + { + if (attr->next_attribute) + attr->next_attribute->prev_attribute_c = attr->prev_attribute_c; + else + node->first_attribute->prev_attribute_c = attr->prev_attribute_c; + + if (attr->prev_attribute_c->next_attribute) + attr->prev_attribute_c->next_attribute = attr->next_attribute; + else + node->first_attribute = attr->next_attribute; + + attr->prev_attribute_c = 0; + attr->next_attribute = 0; + } + + PUGI__FN_NO_INLINE xml_node_struct* append_new_node(xml_node_struct* node, xml_allocator& alloc, xml_node_type type = node_element) + { + if (!alloc.reserve()) return 0; + + xml_node_struct* child = allocate_node(alloc, type); + if (!child) return 0; + + append_node(child, node); + + return child; + } + + PUGI__FN_NO_INLINE xml_attribute_struct* append_new_attribute(xml_node_struct* node, xml_allocator& alloc) + { + if (!alloc.reserve()) return 0; + + xml_attribute_struct* attr = allocate_attribute(alloc); + if (!attr) return 0; + + append_attribute(attr, node); + + return attr; + } +PUGI__NS_END + +// Helper classes for code generation +PUGI__NS_BEGIN + struct opt_false + { + enum { value = 0 }; + }; + + struct opt_true + { + enum { value = 1 }; + }; +PUGI__NS_END + +// Unicode utilities +PUGI__NS_BEGIN + inline uint16_t endian_swap(uint16_t value) + { + return static_cast(((value & 0xff) << 8) | (value >> 8)); + } + + inline uint32_t endian_swap(uint32_t value) + { + return ((value & 0xff) << 24) | ((value & 0xff00) << 8) | ((value & 0xff0000) >> 8) | (value >> 24); + } + + struct utf8_counter + { + typedef size_t value_type; + + static value_type low(value_type result, uint32_t ch) + { + // U+0000..U+007F + if (ch < 0x80) return result + 1; + // U+0080..U+07FF + else if (ch < 0x800) return result + 2; + // U+0800..U+FFFF + else return result + 3; + } + + static value_type high(value_type result, uint32_t) + { + // U+10000..U+10FFFF + return result + 4; + } + }; + + struct utf8_writer + { + typedef uint8_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + // U+0000..U+007F + if (ch < 0x80) + { + *result = static_cast(ch); + return result + 1; + } + // U+0080..U+07FF + else if (ch < 0x800) + { + result[0] = static_cast(0xC0 | (ch >> 6)); + result[1] = static_cast(0x80 | (ch & 0x3F)); + return result + 2; + } + // U+0800..U+FFFF + else + { + result[0] = static_cast(0xE0 | (ch >> 12)); + result[1] = static_cast(0x80 | ((ch >> 6) & 0x3F)); + result[2] = static_cast(0x80 | (ch & 0x3F)); + return result + 3; + } + } + + static value_type high(value_type result, uint32_t ch) + { + // U+10000..U+10FFFF + result[0] = static_cast(0xF0 | (ch >> 18)); + result[1] = static_cast(0x80 | ((ch >> 12) & 0x3F)); + result[2] = static_cast(0x80 | ((ch >> 6) & 0x3F)); + result[3] = static_cast(0x80 | (ch & 0x3F)); + return result + 4; + } + + static value_type any(value_type result, uint32_t ch) + { + return (ch < 0x10000) ? low(result, ch) : high(result, ch); + } + }; + + struct utf16_counter + { + typedef size_t value_type; + + static value_type low(value_type result, uint32_t) + { + return result + 1; + } + + static value_type high(value_type result, uint32_t) + { + return result + 2; + } + }; + + struct utf16_writer + { + typedef uint16_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + *result = static_cast(ch); + + return result + 1; + } + + static value_type high(value_type result, uint32_t ch) + { + uint32_t msh = static_cast(ch - 0x10000) >> 10; + uint32_t lsh = static_cast(ch - 0x10000) & 0x3ff; + + result[0] = static_cast(0xD800 + msh); + result[1] = static_cast(0xDC00 + lsh); + + return result + 2; + } + + static value_type any(value_type result, uint32_t ch) + { + return (ch < 0x10000) ? low(result, ch) : high(result, ch); + } + }; + + struct utf32_counter + { + typedef size_t value_type; + + static value_type low(value_type result, uint32_t) + { + return result + 1; + } + + static value_type high(value_type result, uint32_t) + { + return result + 1; + } + }; + + struct utf32_writer + { + typedef uint32_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + *result = ch; + + return result + 1; + } + + static value_type high(value_type result, uint32_t ch) + { + *result = ch; + + return result + 1; + } + + static value_type any(value_type result, uint32_t ch) + { + *result = ch; + + return result + 1; + } + }; + + struct latin1_writer + { + typedef uint8_t* value_type; + + static value_type low(value_type result, uint32_t ch) + { + *result = static_cast(ch > 255 ? '?' : ch); + + return result + 1; + } + + static value_type high(value_type result, uint32_t ch) + { + (void)ch; + + *result = '?'; + + return result + 1; + } + }; + + struct utf8_decoder + { + typedef uint8_t type; + + template static inline typename Traits::value_type process(const uint8_t* data, size_t size, typename Traits::value_type result, Traits) + { + const uint8_t utf8_byte_mask = 0x3f; + + while (size) + { + uint8_t lead = *data; + + // 0xxxxxxx -> U+0000..U+007F + if (lead < 0x80) + { + result = Traits::low(result, lead); + data += 1; + size -= 1; + + // process aligned single-byte (ascii) blocks + if ((reinterpret_cast(data) & 3) == 0) + { + // round-trip through void* to silence 'cast increases required alignment of target type' warnings + while (size >= 4 && (*static_cast(static_cast(data)) & 0x80808080) == 0) + { + result = Traits::low(result, data[0]); + result = Traits::low(result, data[1]); + result = Traits::low(result, data[2]); + result = Traits::low(result, data[3]); + data += 4; + size -= 4; + } + } + } + // 110xxxxx -> U+0080..U+07FF + else if (static_cast(lead - 0xC0) < 0x20 && size >= 2 && (data[1] & 0xc0) == 0x80) + { + result = Traits::low(result, ((lead & ~0xC0) << 6) | (data[1] & utf8_byte_mask)); + data += 2; + size -= 2; + } + // 1110xxxx -> U+0800-U+FFFF + else if (static_cast(lead - 0xE0) < 0x10 && size >= 3 && (data[1] & 0xc0) == 0x80 && (data[2] & 0xc0) == 0x80) + { + result = Traits::low(result, ((lead & ~0xE0) << 12) | ((data[1] & utf8_byte_mask) << 6) | (data[2] & utf8_byte_mask)); + data += 3; + size -= 3; + } + // 11110xxx -> U+10000..U+10FFFF + else if (static_cast(lead - 0xF0) < 0x08 && size >= 4 && (data[1] & 0xc0) == 0x80 && (data[2] & 0xc0) == 0x80 && (data[3] & 0xc0) == 0x80) + { + result = Traits::high(result, ((lead & ~0xF0) << 18) | ((data[1] & utf8_byte_mask) << 12) | ((data[2] & utf8_byte_mask) << 6) | (data[3] & utf8_byte_mask)); + data += 4; + size -= 4; + } + // 10xxxxxx or 11111xxx -> invalid + else + { + data += 1; + size -= 1; + } + } + + return result; + } + }; + + template struct utf16_decoder + { + typedef uint16_t type; + + template static inline typename Traits::value_type process(const uint16_t* data, size_t size, typename Traits::value_type result, Traits) + { + while (size) + { + uint16_t lead = opt_swap::value ? endian_swap(*data) : *data; + + // U+0000..U+D7FF + if (lead < 0xD800) + { + result = Traits::low(result, lead); + data += 1; + size -= 1; + } + // U+E000..U+FFFF + else if (static_cast(lead - 0xE000) < 0x2000) + { + result = Traits::low(result, lead); + data += 1; + size -= 1; + } + // surrogate pair lead + else if (static_cast(lead - 0xD800) < 0x400 && size >= 2) + { + uint16_t next = opt_swap::value ? endian_swap(data[1]) : data[1]; + + if (static_cast(next - 0xDC00) < 0x400) + { + result = Traits::high(result, 0x10000 + ((lead & 0x3ff) << 10) + (next & 0x3ff)); + data += 2; + size -= 2; + } + else + { + data += 1; + size -= 1; + } + } + else + { + data += 1; + size -= 1; + } + } + + return result; + } + }; + + template struct utf32_decoder + { + typedef uint32_t type; + + template static inline typename Traits::value_type process(const uint32_t* data, size_t size, typename Traits::value_type result, Traits) + { + while (size) + { + uint32_t lead = opt_swap::value ? endian_swap(*data) : *data; + + // U+0000..U+FFFF + if (lead < 0x10000) + { + result = Traits::low(result, lead); + data += 1; + size -= 1; + } + // U+10000..U+10FFFF + else + { + result = Traits::high(result, lead); + data += 1; + size -= 1; + } + } + + return result; + } + }; + + struct latin1_decoder + { + typedef uint8_t type; + + template static inline typename Traits::value_type process(const uint8_t* data, size_t size, typename Traits::value_type result, Traits) + { + while (size) + { + result = Traits::low(result, *data); + data += 1; + size -= 1; + } + + return result; + } + }; + + template struct wchar_selector; + + template <> struct wchar_selector<2> + { + typedef uint16_t type; + typedef utf16_counter counter; + typedef utf16_writer writer; + typedef utf16_decoder decoder; + }; + + template <> struct wchar_selector<4> + { + typedef uint32_t type; + typedef utf32_counter counter; + typedef utf32_writer writer; + typedef utf32_decoder decoder; + }; + + typedef wchar_selector::counter wchar_counter; + typedef wchar_selector::writer wchar_writer; + + struct wchar_decoder + { + typedef wchar_t type; + + template static inline typename Traits::value_type process(const wchar_t* data, size_t size, typename Traits::value_type result, Traits traits) + { + typedef wchar_selector::decoder decoder; + + return decoder::process(reinterpret_cast(data), size, result, traits); + } + }; + +#ifdef PUGIXML_WCHAR_MODE + PUGI__FN void convert_wchar_endian_swap(wchar_t* result, const wchar_t* data, size_t length) + { + for (size_t i = 0; i < length; ++i) + result[i] = static_cast(endian_swap(static_cast::type>(data[i]))); + } +#endif +PUGI__NS_END + +PUGI__NS_BEGIN + enum chartype_t + { + ct_parse_pcdata = 1, // \0, &, \r, < + ct_parse_attr = 2, // \0, &, \r, ', " + ct_parse_attr_ws = 4, // \0, &, \r, ', ", \n, tab + ct_space = 8, // \r, \n, space, tab + ct_parse_cdata = 16, // \0, ], >, \r + ct_parse_comment = 32, // \0, -, >, \r + ct_symbol = 64, // Any symbol > 127, a-z, A-Z, 0-9, _, :, -, . + ct_start_symbol = 128 // Any symbol > 127, a-z, A-Z, _, : + }; + + static const unsigned char chartype_table[256] = + { + 55, 0, 0, 0, 0, 0, 0, 0, 0, 12, 12, 0, 0, 63, 0, 0, // 0-15 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 16-31 + 8, 0, 6, 0, 0, 0, 7, 6, 0, 0, 0, 0, 0, 96, 64, 0, // 32-47 + 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 192, 0, 1, 0, 48, 0, // 48-63 + 0, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, // 64-79 + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 0, 0, 16, 0, 192, // 80-95 + 0, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, // 96-111 + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 0, 0, 0, 0, 0, // 112-127 + + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, // 128+ + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, + 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192, 192 + }; + + enum chartypex_t + { + ctx_special_pcdata = 1, // Any symbol >= 0 and < 32 (except \t, \r, \n), &, <, > + ctx_special_attr = 2, // Any symbol >= 0 and < 32, &, <, ", ' + ctx_start_symbol = 4, // Any symbol > 127, a-z, A-Z, _ + ctx_digit = 8, // 0-9 + ctx_symbol = 16 // Any symbol > 127, a-z, A-Z, 0-9, _, -, . + }; + + static const unsigned char chartypex_table[256] = + { + 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, // 0-15 + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 16-31 + 0, 0, 2, 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 16, 16, 0, // 32-47 + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 0, 0, 3, 0, 1, 0, // 48-63 + + 0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, // 64-79 + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 0, 0, 0, 0, 20, // 80-95 + 0, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, // 96-111 + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 0, 0, 0, 0, 0, // 112-127 + + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, // 128+ + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20 + }; + +#ifdef PUGIXML_WCHAR_MODE + #define PUGI__IS_CHARTYPE_IMPL(c, ct, table) ((static_cast(c) < 128 ? table[static_cast(c)] : table[128]) & (ct)) +#else + #define PUGI__IS_CHARTYPE_IMPL(c, ct, table) (table[static_cast(c)] & (ct)) +#endif + + #define PUGI__IS_CHARTYPE(c, ct) PUGI__IS_CHARTYPE_IMPL(c, ct, chartype_table) + #define PUGI__IS_CHARTYPEX(c, ct) PUGI__IS_CHARTYPE_IMPL(c, ct, chartypex_table) + + PUGI__FN bool is_little_endian() + { + unsigned int ui = 1; + + return *reinterpret_cast(&ui) == 1; + } + + PUGI__FN xml_encoding get_wchar_encoding() + { + PUGI__STATIC_ASSERT(sizeof(wchar_t) == 2 || sizeof(wchar_t) == 4); + + if (sizeof(wchar_t) == 2) + return is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + else + return is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + } + + PUGI__FN bool parse_declaration_encoding(const uint8_t* data, size_t size, const uint8_t*& out_encoding, size_t& out_length) + { + #define PUGI__SCANCHAR(ch) { if (offset >= size || data[offset] != ch) return false; offset++; } + #define PUGI__SCANCHARTYPE(ct) { while (offset < size && PUGI__IS_CHARTYPE(data[offset], ct)) offset++; } + + // check if we have a non-empty XML declaration + if (size < 6 || !((data[0] == '<') & (data[1] == '?') & (data[2] == 'x') & (data[3] == 'm') & (data[4] == 'l') && PUGI__IS_CHARTYPE(data[5], ct_space))) + return false; + + // scan XML declaration until the encoding field + for (size_t i = 6; i + 1 < size; ++i) + { + // declaration can not contain ? in quoted values + if (data[i] == '?') + return false; + + if (data[i] == 'e' && data[i + 1] == 'n') + { + size_t offset = i; + + // encoding follows the version field which can't contain 'en' so this has to be the encoding if XML is well formed + PUGI__SCANCHAR('e'); PUGI__SCANCHAR('n'); PUGI__SCANCHAR('c'); PUGI__SCANCHAR('o'); + PUGI__SCANCHAR('d'); PUGI__SCANCHAR('i'); PUGI__SCANCHAR('n'); PUGI__SCANCHAR('g'); + + // S? = S? + PUGI__SCANCHARTYPE(ct_space); + PUGI__SCANCHAR('='); + PUGI__SCANCHARTYPE(ct_space); + + // the only two valid delimiters are ' and " + uint8_t delimiter = (offset < size && data[offset] == '"') ? '"' : '\''; + + PUGI__SCANCHAR(delimiter); + + size_t start = offset; + + out_encoding = data + offset; + + PUGI__SCANCHARTYPE(ct_symbol); + + out_length = offset - start; + + PUGI__SCANCHAR(delimiter); + + return true; + } + } + + return false; + + #undef PUGI__SCANCHAR + #undef PUGI__SCANCHARTYPE + } + + PUGI__FN xml_encoding guess_buffer_encoding(const uint8_t* data, size_t size) + { + // skip encoding autodetection if input buffer is too small + if (size < 4) return encoding_utf8; + + uint8_t d0 = data[0], d1 = data[1], d2 = data[2], d3 = data[3]; + + // look for BOM in first few bytes + if (d0 == 0 && d1 == 0 && d2 == 0xfe && d3 == 0xff) return encoding_utf32_be; + if (d0 == 0xff && d1 == 0xfe && d2 == 0 && d3 == 0) return encoding_utf32_le; + if (d0 == 0xfe && d1 == 0xff) return encoding_utf16_be; + if (d0 == 0xff && d1 == 0xfe) return encoding_utf16_le; + if (d0 == 0xef && d1 == 0xbb && d2 == 0xbf) return encoding_utf8; + + // look for <, (contents); + + return guess_buffer_encoding(data, size); + } + + PUGI__FN bool get_mutable_buffer(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, bool is_mutable) + { + size_t length = size / sizeof(char_t); + + if (is_mutable) + { + out_buffer = static_cast(const_cast(contents)); + out_length = length; + } + else + { + char_t* buffer = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!buffer) return false; + + if (contents) + memcpy(buffer, contents, length * sizeof(char_t)); + else + assert(length == 0); + + buffer[length] = 0; + + out_buffer = buffer; + out_length = length + 1; + } + + return true; + } + +#ifdef PUGIXML_WCHAR_MODE + PUGI__FN bool need_endian_swap_utf(xml_encoding le, xml_encoding re) + { + return (le == encoding_utf16_be && re == encoding_utf16_le) || (le == encoding_utf16_le && re == encoding_utf16_be) || + (le == encoding_utf32_be && re == encoding_utf32_le) || (le == encoding_utf32_le && re == encoding_utf32_be); + } + + PUGI__FN bool convert_buffer_endian_swap(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, bool is_mutable) + { + const char_t* data = static_cast(contents); + size_t length = size / sizeof(char_t); + + if (is_mutable) + { + char_t* buffer = const_cast(data); + + convert_wchar_endian_swap(buffer, data, length); + + out_buffer = buffer; + out_length = length; + } + else + { + char_t* buffer = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!buffer) return false; + + convert_wchar_endian_swap(buffer, data, length); + buffer[length] = 0; + + out_buffer = buffer; + out_length = length + 1; + } + + return true; + } + + template PUGI__FN bool convert_buffer_generic(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, D) + { + const typename D::type* data = static_cast(contents); + size_t data_length = size / sizeof(typename D::type); + + // first pass: get length in wchar_t units + size_t length = D::process(data, data_length, 0, wchar_counter()); + + // allocate buffer of suitable length + char_t* buffer = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!buffer) return false; + + // second pass: convert utf16 input to wchar_t + wchar_writer::value_type obegin = reinterpret_cast(buffer); + wchar_writer::value_type oend = D::process(data, data_length, obegin, wchar_writer()); + + assert(oend == obegin + length); + *oend = 0; + + out_buffer = buffer; + out_length = length + 1; + + return true; + } + + PUGI__FN bool convert_buffer(char_t*& out_buffer, size_t& out_length, xml_encoding encoding, const void* contents, size_t size, bool is_mutable) + { + // get native encoding + xml_encoding wchar_encoding = get_wchar_encoding(); + + // fast path: no conversion required + if (encoding == wchar_encoding) + return get_mutable_buffer(out_buffer, out_length, contents, size, is_mutable); + + // only endian-swapping is required + if (need_endian_swap_utf(encoding, wchar_encoding)) + return convert_buffer_endian_swap(out_buffer, out_length, contents, size, is_mutable); + + // source encoding is utf8 + if (encoding == encoding_utf8) + return convert_buffer_generic(out_buffer, out_length, contents, size, utf8_decoder()); + + // source encoding is utf16 + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + return (native_encoding == encoding) ? + convert_buffer_generic(out_buffer, out_length, contents, size, utf16_decoder()) : + convert_buffer_generic(out_buffer, out_length, contents, size, utf16_decoder()); + } + + // source encoding is utf32 + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + return (native_encoding == encoding) ? + convert_buffer_generic(out_buffer, out_length, contents, size, utf32_decoder()) : + convert_buffer_generic(out_buffer, out_length, contents, size, utf32_decoder()); + } + + // source encoding is latin1 + if (encoding == encoding_latin1) + return convert_buffer_generic(out_buffer, out_length, contents, size, latin1_decoder()); + + assert(false && "Invalid encoding"); // unreachable + return false; + } +#else + template PUGI__FN bool convert_buffer_generic(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, D) + { + const typename D::type* data = static_cast(contents); + size_t data_length = size / sizeof(typename D::type); + + // first pass: get length in utf8 units + size_t length = D::process(data, data_length, 0, utf8_counter()); + + // allocate buffer of suitable length + char_t* buffer = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!buffer) return false; + + // second pass: convert utf16 input to utf8 + uint8_t* obegin = reinterpret_cast(buffer); + uint8_t* oend = D::process(data, data_length, obegin, utf8_writer()); + + assert(oend == obegin + length); + *oend = 0; + + out_buffer = buffer; + out_length = length + 1; + + return true; + } + + PUGI__FN size_t get_latin1_7bit_prefix_length(const uint8_t* data, size_t size) + { + for (size_t i = 0; i < size; ++i) + if (data[i] > 127) + return i; + + return size; + } + + PUGI__FN bool convert_buffer_latin1(char_t*& out_buffer, size_t& out_length, const void* contents, size_t size, bool is_mutable) + { + const uint8_t* data = static_cast(contents); + size_t data_length = size; + + // get size of prefix that does not need utf8 conversion + size_t prefix_length = get_latin1_7bit_prefix_length(data, data_length); + assert(prefix_length <= data_length); + + const uint8_t* postfix = data + prefix_length; + size_t postfix_length = data_length - prefix_length; + + // if no conversion is needed, just return the original buffer + if (postfix_length == 0) return get_mutable_buffer(out_buffer, out_length, contents, size, is_mutable); + + // first pass: get length in utf8 units + size_t length = prefix_length + latin1_decoder::process(postfix, postfix_length, 0, utf8_counter()); + + // allocate buffer of suitable length + char_t* buffer = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!buffer) return false; + + // second pass: convert latin1 input to utf8 + memcpy(buffer, data, prefix_length); + + uint8_t* obegin = reinterpret_cast(buffer); + uint8_t* oend = latin1_decoder::process(postfix, postfix_length, obegin + prefix_length, utf8_writer()); + + assert(oend == obegin + length); + *oend = 0; + + out_buffer = buffer; + out_length = length + 1; + + return true; + } + + PUGI__FN bool convert_buffer(char_t*& out_buffer, size_t& out_length, xml_encoding encoding, const void* contents, size_t size, bool is_mutable) + { + // fast path: no conversion required + if (encoding == encoding_utf8) + return get_mutable_buffer(out_buffer, out_length, contents, size, is_mutable); + + // source encoding is utf16 + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + return (native_encoding == encoding) ? + convert_buffer_generic(out_buffer, out_length, contents, size, utf16_decoder()) : + convert_buffer_generic(out_buffer, out_length, contents, size, utf16_decoder()); + } + + // source encoding is utf32 + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + return (native_encoding == encoding) ? + convert_buffer_generic(out_buffer, out_length, contents, size, utf32_decoder()) : + convert_buffer_generic(out_buffer, out_length, contents, size, utf32_decoder()); + } + + // source encoding is latin1 + if (encoding == encoding_latin1) + return convert_buffer_latin1(out_buffer, out_length, contents, size, is_mutable); + + assert(false && "Invalid encoding"); // unreachable + return false; + } +#endif + + PUGI__FN size_t as_utf8_begin(const wchar_t* str, size_t length) + { + // get length in utf8 characters + return wchar_decoder::process(str, length, 0, utf8_counter()); + } + + PUGI__FN void as_utf8_end(char* buffer, size_t size, const wchar_t* str, size_t length) + { + // convert to utf8 + uint8_t* begin = reinterpret_cast(buffer); + uint8_t* end = wchar_decoder::process(str, length, begin, utf8_writer()); + + assert(begin + size == end); + (void)!end; + (void)!size; + } + +#ifndef PUGIXML_NO_STL + PUGI__FN std::string as_utf8_impl(const wchar_t* str, size_t length) + { + // first pass: get length in utf8 characters + size_t size = as_utf8_begin(str, length); + + // allocate resulting string + std::string result; + result.resize(size); + + // second pass: convert to utf8 + if (size > 0) as_utf8_end(&result[0], size, str, length); + + return result; + } + + PUGI__FN std::basic_string as_wide_impl(const char* str, size_t size) + { + const uint8_t* data = reinterpret_cast(str); + + // first pass: get length in wchar_t units + size_t length = utf8_decoder::process(data, size, 0, wchar_counter()); + + // allocate resulting string + std::basic_string result; + result.resize(length); + + // second pass: convert to wchar_t + if (length > 0) + { + wchar_writer::value_type begin = reinterpret_cast(&result[0]); + wchar_writer::value_type end = utf8_decoder::process(data, size, begin, wchar_writer()); + + assert(begin + length == end); + (void)!end; + } + + return result; + } +#endif + + template + inline bool strcpy_insitu_allow(size_t length, const Header& header, uintptr_t header_mask, char_t* target) + { + // never reuse shared memory + if (header & xml_memory_page_contents_shared_mask) return false; + + size_t target_length = strlength(target); + + // always reuse document buffer memory if possible + if ((header & header_mask) == 0) return target_length >= length; + + // reuse heap memory if waste is not too great + const size_t reuse_threshold = 32; + + return target_length >= length && (target_length < reuse_threshold || target_length - length < target_length / 2); + } + + template + PUGI__FN bool strcpy_insitu(String& dest, Header& header, uintptr_t header_mask, const char_t* source, size_t source_length) + { + if (source_length == 0) + { + // empty string and null pointer are equivalent, so just deallocate old memory + xml_allocator* alloc = PUGI__GETPAGE_IMPL(header)->allocator; + + if (header & header_mask) alloc->deallocate_string(dest); + + // mark the string as not allocated + dest = 0; + header &= ~header_mask; + + return true; + } + else if (dest && strcpy_insitu_allow(source_length, header, header_mask, dest)) + { + // we can reuse old buffer, so just copy the new data (including zero terminator) + memcpy(dest, source, source_length * sizeof(char_t)); + dest[source_length] = 0; + + return true; + } + else + { + xml_allocator* alloc = PUGI__GETPAGE_IMPL(header)->allocator; + + if (!alloc->reserve()) return false; + + // allocate new buffer + char_t* buf = alloc->allocate_string(source_length + 1); + if (!buf) return false; + + // copy the string (including zero terminator) + memcpy(buf, source, source_length * sizeof(char_t)); + buf[source_length] = 0; + + // deallocate old buffer (*after* the above to protect against overlapping memory and/or allocation failures) + if (header & header_mask) alloc->deallocate_string(dest); + + // the string is now allocated, so set the flag + dest = buf; + header |= header_mask; + + return true; + } + } + + struct gap + { + char_t* end; + size_t size; + + gap(): end(0), size(0) + { + } + + // Push new gap, move s count bytes further (skipping the gap). + // Collapse previous gap. + void push(char_t*& s, size_t count) + { + if (end) // there was a gap already; collapse it + { + // Move [old_gap_end, new_gap_start) to [old_gap_start, ...) + assert(s >= end); + memmove(end - size, end, reinterpret_cast(s) - reinterpret_cast(end)); + } + + s += count; // end of current gap + + // "merge" two gaps + end = s; + size += count; + } + + // Collapse all gaps, return past-the-end pointer + char_t* flush(char_t* s) + { + if (end) + { + // Move [old_gap_end, current_pos) to [old_gap_start, ...) + assert(s >= end); + memmove(end - size, end, reinterpret_cast(s) - reinterpret_cast(end)); + + return s - size; + } + else return s; + } + }; + + PUGI__FN char_t* strconv_escape(char_t* s, gap& g) + { + char_t* stre = s + 1; + + switch (*stre) + { + case '#': // &#... + { + unsigned int ucsc = 0; + + if (stre[1] == 'x') // &#x... (hex code) + { + stre += 2; + + char_t ch = *stre; + + if (ch == ';') return stre; + + for (;;) + { + if (static_cast(ch - '0') <= 9) + ucsc = 16 * ucsc + (ch - '0'); + else if (static_cast((ch | ' ') - 'a') <= 5) + ucsc = 16 * ucsc + ((ch | ' ') - 'a' + 10); + else if (ch == ';') + break; + else // cancel + return stre; + + ch = *++stre; + } + + ++stre; + } + else // &#... (dec code) + { + char_t ch = *++stre; + + if (ch == ';') return stre; + + for (;;) + { + if (static_cast(ch - '0') <= 9) + ucsc = 10 * ucsc + (ch - '0'); + else if (ch == ';') + break; + else // cancel + return stre; + + ch = *++stre; + } + + ++stre; + } + + #ifdef PUGIXML_WCHAR_MODE + s = reinterpret_cast(wchar_writer::any(reinterpret_cast(s), ucsc)); + #else + s = reinterpret_cast(utf8_writer::any(reinterpret_cast(s), ucsc)); + #endif + + g.push(s, stre - s); + return stre; + } + + case 'a': // &a + { + ++stre; + + if (*stre == 'm') // &am + { + if (*++stre == 'p' && *++stre == ';') // & + { + *s++ = '&'; + ++stre; + + g.push(s, stre - s); + return stre; + } + } + else if (*stre == 'p') // &ap + { + if (*++stre == 'o' && *++stre == 's' && *++stre == ';') // ' + { + *s++ = '\''; + ++stre; + + g.push(s, stre - s); + return stre; + } + } + break; + } + + case 'g': // &g + { + if (*++stre == 't' && *++stre == ';') // > + { + *s++ = '>'; + ++stre; + + g.push(s, stre - s); + return stre; + } + break; + } + + case 'l': // &l + { + if (*++stre == 't' && *++stre == ';') // < + { + *s++ = '<'; + ++stre; + + g.push(s, stre - s); + return stre; + } + break; + } + + case 'q': // &q + { + if (*++stre == 'u' && *++stre == 'o' && *++stre == 't' && *++stre == ';') // " + { + *s++ = '"'; + ++stre; + + g.push(s, stre - s); + return stre; + } + break; + } + + default: + break; + } + + return stre; + } + + // Parser utilities + #define PUGI__ENDSWITH(c, e) ((c) == (e) || ((c) == 0 && endch == (e))) + #define PUGI__SKIPWS() { while (PUGI__IS_CHARTYPE(*s, ct_space)) ++s; } + #define PUGI__OPTSET(OPT) ( optmsk & (OPT) ) + #define PUGI__PUSHNODE(TYPE) { cursor = append_new_node(cursor, *alloc, TYPE); if (!cursor) PUGI__THROW_ERROR(status_out_of_memory, s); } + #define PUGI__POPNODE() { cursor = cursor->parent; } + #define PUGI__SCANFOR(X) { while (*s != 0 && !(X)) ++s; } + #define PUGI__SCANWHILE(X) { while (X) ++s; } + #define PUGI__SCANWHILE_UNROLL(X) { for (;;) { char_t ss = s[0]; if (PUGI__UNLIKELY(!(X))) { break; } ss = s[1]; if (PUGI__UNLIKELY(!(X))) { s += 1; break; } ss = s[2]; if (PUGI__UNLIKELY(!(X))) { s += 2; break; } ss = s[3]; if (PUGI__UNLIKELY(!(X))) { s += 3; break; } s += 4; } } + #define PUGI__ENDSEG() { ch = *s; *s = 0; ++s; } + #define PUGI__THROW_ERROR(err, m) return error_offset = m, error_status = err, static_cast(0) + #define PUGI__CHECK_ERROR(err, m) { if (*s == 0) PUGI__THROW_ERROR(err, m); } + + PUGI__FN char_t* strconv_comment(char_t* s, char_t endch) + { + gap g; + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_comment)); + + if (*s == '\r') // Either a single 0x0d or 0x0d 0x0a pair + { + *s++ = '\n'; // replace first one with 0x0a + + if (*s == '\n') g.push(s, 1); + } + else if (s[0] == '-' && s[1] == '-' && PUGI__ENDSWITH(s[2], '>')) // comment ends here + { + *g.flush(s) = 0; + + return s + (s[2] == '>' ? 3 : 2); + } + else if (*s == 0) + { + return 0; + } + else ++s; + } + } + + PUGI__FN char_t* strconv_cdata(char_t* s, char_t endch) + { + gap g; + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_cdata)); + + if (*s == '\r') // Either a single 0x0d or 0x0d 0x0a pair + { + *s++ = '\n'; // replace first one with 0x0a + + if (*s == '\n') g.push(s, 1); + } + else if (s[0] == ']' && s[1] == ']' && PUGI__ENDSWITH(s[2], '>')) // CDATA ends here + { + *g.flush(s) = 0; + + return s + 1; + } + else if (*s == 0) + { + return 0; + } + else ++s; + } + } + + typedef char_t* (*strconv_pcdata_t)(char_t*); + + template struct strconv_pcdata_impl + { + static char_t* parse(char_t* s) + { + gap g; + + char_t* begin = s; + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_pcdata)); + + if (*s == '<') // PCDATA ends here + { + char_t* end = g.flush(s); + + if (opt_trim::value) + while (end > begin && PUGI__IS_CHARTYPE(end[-1], ct_space)) + --end; + + *end = 0; + + return s + 1; + } + else if (opt_eol::value && *s == '\r') // Either a single 0x0d or 0x0d 0x0a pair + { + *s++ = '\n'; // replace first one with 0x0a + + if (*s == '\n') g.push(s, 1); + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (*s == 0) + { + char_t* end = g.flush(s); + + if (opt_trim::value) + while (end > begin && PUGI__IS_CHARTYPE(end[-1], ct_space)) + --end; + + *end = 0; + + return s; + } + else ++s; + } + } + }; + + PUGI__FN strconv_pcdata_t get_strconv_pcdata(unsigned int optmask) + { + PUGI__STATIC_ASSERT(parse_escapes == 0x10 && parse_eol == 0x20 && parse_trim_pcdata == 0x0800); + + switch (((optmask >> 4) & 3) | ((optmask >> 9) & 4)) // get bitmask for flags (trim eol escapes); this simultaneously checks 3 options from assertion above + { + case 0: return strconv_pcdata_impl::parse; + case 1: return strconv_pcdata_impl::parse; + case 2: return strconv_pcdata_impl::parse; + case 3: return strconv_pcdata_impl::parse; + case 4: return strconv_pcdata_impl::parse; + case 5: return strconv_pcdata_impl::parse; + case 6: return strconv_pcdata_impl::parse; + case 7: return strconv_pcdata_impl::parse; + default: assert(false); return 0; // unreachable + } + } + + typedef char_t* (*strconv_attribute_t)(char_t*, char_t); + + template struct strconv_attribute_impl + { + static char_t* parse_wnorm(char_t* s, char_t end_quote) + { + gap g; + + // trim leading whitespaces + if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + char_t* str = s; + + do ++str; + while (PUGI__IS_CHARTYPE(*str, ct_space)); + + g.push(s, str - s); + } + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_attr_ws | ct_space)); + + if (*s == end_quote) + { + char_t* str = g.flush(s); + + do *str-- = 0; + while (PUGI__IS_CHARTYPE(*str, ct_space)); + + return s + 1; + } + else if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + *s++ = ' '; + + if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + char_t* str = s + 1; + while (PUGI__IS_CHARTYPE(*str, ct_space)) ++str; + + g.push(s, str - s); + } + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + + static char_t* parse_wconv(char_t* s, char_t end_quote) + { + gap g; + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_attr_ws)); + + if (*s == end_quote) + { + *g.flush(s) = 0; + + return s + 1; + } + else if (PUGI__IS_CHARTYPE(*s, ct_space)) + { + if (*s == '\r') + { + *s++ = ' '; + + if (*s == '\n') g.push(s, 1); + } + else *s++ = ' '; + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + + static char_t* parse_eol(char_t* s, char_t end_quote) + { + gap g; + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_attr)); + + if (*s == end_quote) + { + *g.flush(s) = 0; + + return s + 1; + } + else if (*s == '\r') + { + *s++ = '\n'; + + if (*s == '\n') g.push(s, 1); + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + + static char_t* parse_simple(char_t* s, char_t end_quote) + { + gap g; + + while (true) + { + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPE(ss, ct_parse_attr)); + + if (*s == end_quote) + { + *g.flush(s) = 0; + + return s + 1; + } + else if (opt_escape::value && *s == '&') + { + s = strconv_escape(s, g); + } + else if (!*s) + { + return 0; + } + else ++s; + } + } + }; + + PUGI__FN strconv_attribute_t get_strconv_attribute(unsigned int optmask) + { + PUGI__STATIC_ASSERT(parse_escapes == 0x10 && parse_eol == 0x20 && parse_wconv_attribute == 0x40 && parse_wnorm_attribute == 0x80); + + switch ((optmask >> 4) & 15) // get bitmask for flags (wnorm wconv eol escapes); this simultaneously checks 4 options from assertion above + { + case 0: return strconv_attribute_impl::parse_simple; + case 1: return strconv_attribute_impl::parse_simple; + case 2: return strconv_attribute_impl::parse_eol; + case 3: return strconv_attribute_impl::parse_eol; + case 4: return strconv_attribute_impl::parse_wconv; + case 5: return strconv_attribute_impl::parse_wconv; + case 6: return strconv_attribute_impl::parse_wconv; + case 7: return strconv_attribute_impl::parse_wconv; + case 8: return strconv_attribute_impl::parse_wnorm; + case 9: return strconv_attribute_impl::parse_wnorm; + case 10: return strconv_attribute_impl::parse_wnorm; + case 11: return strconv_attribute_impl::parse_wnorm; + case 12: return strconv_attribute_impl::parse_wnorm; + case 13: return strconv_attribute_impl::parse_wnorm; + case 14: return strconv_attribute_impl::parse_wnorm; + case 15: return strconv_attribute_impl::parse_wnorm; + default: assert(false); return 0; // unreachable + } + } + + inline xml_parse_result make_parse_result(xml_parse_status status, ptrdiff_t offset = 0) + { + xml_parse_result result; + result.status = status; + result.offset = offset; + + return result; + } + + struct xml_parser + { + xml_allocator* alloc; + char_t* error_offset; + xml_parse_status error_status; + + xml_parser(xml_allocator* alloc_): alloc(alloc_), error_offset(0), error_status(status_ok) + { + } + + // DOCTYPE consists of nested sections of the following possible types: + // , , "...", '...' + // + // + // First group can not contain nested groups + // Second group can contain nested groups of the same type + // Third group can contain all other groups + char_t* parse_doctype_primitive(char_t* s) + { + if (*s == '"' || *s == '\'') + { + // quoted string + char_t ch = *s++; + PUGI__SCANFOR(*s == ch); + if (!*s) PUGI__THROW_ERROR(status_bad_doctype, s); + + s++; + } + else if (s[0] == '<' && s[1] == '?') + { + // + s += 2; + PUGI__SCANFOR(s[0] == '?' && s[1] == '>'); // no need for ENDSWITH because ?> can't terminate proper doctype + if (!*s) PUGI__THROW_ERROR(status_bad_doctype, s); + + s += 2; + } + else if (s[0] == '<' && s[1] == '!' && s[2] == '-' && s[3] == '-') + { + s += 4; + PUGI__SCANFOR(s[0] == '-' && s[1] == '-' && s[2] == '>'); // no need for ENDSWITH because --> can't terminate proper doctype + if (!*s) PUGI__THROW_ERROR(status_bad_doctype, s); + + s += 3; + } + else PUGI__THROW_ERROR(status_bad_doctype, s); + + return s; + } + + char_t* parse_doctype_ignore(char_t* s) + { + size_t depth = 0; + + assert(s[0] == '<' && s[1] == '!' && s[2] == '['); + s += 3; + + while (*s) + { + if (s[0] == '<' && s[1] == '!' && s[2] == '[') + { + // nested ignore section + s += 3; + depth++; + } + else if (s[0] == ']' && s[1] == ']' && s[2] == '>') + { + // ignore section end + s += 3; + + if (depth == 0) + return s; + + depth--; + } + else s++; + } + + PUGI__THROW_ERROR(status_bad_doctype, s); + } + + char_t* parse_doctype_group(char_t* s, char_t endch) + { + size_t depth = 0; + + assert((s[0] == '<' || s[0] == 0) && s[1] == '!'); + s += 2; + + while (*s) + { + if (s[0] == '<' && s[1] == '!' && s[2] != '-') + { + if (s[2] == '[') + { + // ignore + s = parse_doctype_ignore(s); + if (!s) return s; + } + else + { + // some control group + s += 2; + depth++; + } + } + else if (s[0] == '<' || s[0] == '"' || s[0] == '\'') + { + // unknown tag (forbidden), or some primitive group + s = parse_doctype_primitive(s); + if (!s) return s; + } + else if (*s == '>') + { + if (depth == 0) + return s; + + depth--; + s++; + } + else s++; + } + + if (depth != 0 || endch != '>') PUGI__THROW_ERROR(status_bad_doctype, s); + + return s; + } + + char_t* parse_exclamation(char_t* s, xml_node_struct* cursor, unsigned int optmsk, char_t endch) + { + // parse node contents, starting with exclamation mark + ++s; + + if (*s == '-') // 'value = s; // Save the offset. + } + + if (PUGI__OPTSET(parse_eol) && PUGI__OPTSET(parse_comments)) + { + s = strconv_comment(s, endch); + + if (!s) PUGI__THROW_ERROR(status_bad_comment, cursor->value); + } + else + { + // Scan for terminating '-->'. + PUGI__SCANFOR(s[0] == '-' && s[1] == '-' && PUGI__ENDSWITH(s[2], '>')); + PUGI__CHECK_ERROR(status_bad_comment, s); + + if (PUGI__OPTSET(parse_comments)) + *s = 0; // Zero-terminate this segment at the first terminating '-'. + + s += (s[2] == '>' ? 3 : 2); // Step over the '\0->'. + } + } + else PUGI__THROW_ERROR(status_bad_comment, s); + } + else if (*s == '[') + { + // 'value = s; // Save the offset. + + if (PUGI__OPTSET(parse_eol)) + { + s = strconv_cdata(s, endch); + + if (!s) PUGI__THROW_ERROR(status_bad_cdata, cursor->value); + } + else + { + // Scan for terminating ']]>'. + PUGI__SCANFOR(s[0] == ']' && s[1] == ']' && PUGI__ENDSWITH(s[2], '>')); + PUGI__CHECK_ERROR(status_bad_cdata, s); + + *s++ = 0; // Zero-terminate this segment. + } + } + else // Flagged for discard, but we still have to scan for the terminator. + { + // Scan for terminating ']]>'. + PUGI__SCANFOR(s[0] == ']' && s[1] == ']' && PUGI__ENDSWITH(s[2], '>')); + PUGI__CHECK_ERROR(status_bad_cdata, s); + + ++s; + } + + s += (s[1] == '>' ? 2 : 1); // Step over the last ']>'. + } + else PUGI__THROW_ERROR(status_bad_cdata, s); + } + else if (s[0] == 'D' && s[1] == 'O' && s[2] == 'C' && s[3] == 'T' && s[4] == 'Y' && s[5] == 'P' && PUGI__ENDSWITH(s[6], 'E')) + { + s -= 2; + + if (cursor->parent) PUGI__THROW_ERROR(status_bad_doctype, s); + + char_t* mark = s + 9; + + s = parse_doctype_group(s, endch); + if (!s) return s; + + assert((*s == 0 && endch == '>') || *s == '>'); + if (*s) *s++ = 0; + + if (PUGI__OPTSET(parse_doctype)) + { + while (PUGI__IS_CHARTYPE(*mark, ct_space)) ++mark; + + PUGI__PUSHNODE(node_doctype); + + cursor->value = mark; + } + } + else if (*s == 0 && endch == '-') PUGI__THROW_ERROR(status_bad_comment, s); + else if (*s == 0 && endch == '[') PUGI__THROW_ERROR(status_bad_cdata, s); + else PUGI__THROW_ERROR(status_unrecognized_tag, s); + + return s; + } + + char_t* parse_question(char_t* s, xml_node_struct*& ref_cursor, unsigned int optmsk, char_t endch) + { + // load into registers + xml_node_struct* cursor = ref_cursor; + char_t ch = 0; + + // parse node contents, starting with question mark + ++s; + + // read PI target + char_t* target = s; + + if (!PUGI__IS_CHARTYPE(*s, ct_start_symbol)) PUGI__THROW_ERROR(status_bad_pi, s); + + PUGI__SCANWHILE(PUGI__IS_CHARTYPE(*s, ct_symbol)); + PUGI__CHECK_ERROR(status_bad_pi, s); + + // determine node type; stricmp / strcasecmp is not portable + bool declaration = (target[0] | ' ') == 'x' && (target[1] | ' ') == 'm' && (target[2] | ' ') == 'l' && target + 3 == s; + + if (declaration ? PUGI__OPTSET(parse_declaration) : PUGI__OPTSET(parse_pi)) + { + if (declaration) + { + // disallow non top-level declarations + if (cursor->parent) PUGI__THROW_ERROR(status_bad_pi, s); + + PUGI__PUSHNODE(node_declaration); + } + else + { + PUGI__PUSHNODE(node_pi); + } + + cursor->name = target; + + PUGI__ENDSEG(); + + // parse value/attributes + if (ch == '?') + { + // empty node + if (!PUGI__ENDSWITH(*s, '>')) PUGI__THROW_ERROR(status_bad_pi, s); + s += (*s == '>'); + + PUGI__POPNODE(); + } + else if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + PUGI__SKIPWS(); + + // scan for tag end + char_t* value = s; + + PUGI__SCANFOR(s[0] == '?' && PUGI__ENDSWITH(s[1], '>')); + PUGI__CHECK_ERROR(status_bad_pi, s); + + if (declaration) + { + // replace ending ? with / so that 'element' terminates properly + *s = '/'; + + // we exit from this function with cursor at node_declaration, which is a signal to parse() to go to LOC_ATTRIBUTES + s = value; + } + else + { + // store value and step over > + cursor->value = value; + + PUGI__POPNODE(); + + PUGI__ENDSEG(); + + s += (*s == '>'); + } + } + else PUGI__THROW_ERROR(status_bad_pi, s); + } + else + { + // scan for tag end + PUGI__SCANFOR(s[0] == '?' && PUGI__ENDSWITH(s[1], '>')); + PUGI__CHECK_ERROR(status_bad_pi, s); + + s += (s[1] == '>' ? 2 : 1); + } + + // store from registers + ref_cursor = cursor; + + return s; + } + + char_t* parse_tree(char_t* s, xml_node_struct* root, unsigned int optmsk, char_t endch) + { + strconv_attribute_t strconv_attribute = get_strconv_attribute(optmsk); + strconv_pcdata_t strconv_pcdata = get_strconv_pcdata(optmsk); + + char_t ch = 0; + xml_node_struct* cursor = root; + char_t* mark = s; + + while (*s != 0) + { + if (*s == '<') + { + ++s; + + LOC_TAG: + if (PUGI__IS_CHARTYPE(*s, ct_start_symbol)) // '<#...' + { + PUGI__PUSHNODE(node_element); // Append a new node to the tree. + + cursor->name = s; + + PUGI__SCANWHILE_UNROLL(PUGI__IS_CHARTYPE(ss, ct_symbol)); // Scan for a terminator. + PUGI__ENDSEG(); // Save char in 'ch', terminate & step over. + + if (ch == '>') + { + // end of tag + } + else if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + LOC_ATTRIBUTES: + while (true) + { + PUGI__SKIPWS(); // Eat any whitespace. + + if (PUGI__IS_CHARTYPE(*s, ct_start_symbol)) // <... #... + { + xml_attribute_struct* a = append_new_attribute(cursor, *alloc); // Make space for this attribute. + if (!a) PUGI__THROW_ERROR(status_out_of_memory, s); + + a->name = s; // Save the offset. + + PUGI__SCANWHILE_UNROLL(PUGI__IS_CHARTYPE(ss, ct_symbol)); // Scan for a terminator. + PUGI__ENDSEG(); // Save char in 'ch', terminate & step over. + + if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + PUGI__SKIPWS(); // Eat any whitespace. + + ch = *s; + ++s; + } + + if (ch == '=') // '<... #=...' + { + PUGI__SKIPWS(); // Eat any whitespace. + + if (*s == '"' || *s == '\'') // '<... #="...' + { + ch = *s; // Save quote char to avoid breaking on "''" -or- '""'. + ++s; // Step over the quote. + a->value = s; // Save the offset. + + s = strconv_attribute(s, ch); + + if (!s) PUGI__THROW_ERROR(status_bad_attribute, a->value); + + // After this line the loop continues from the start; + // Whitespaces, / and > are ok, symbols and EOF are wrong, + // everything else will be detected + if (PUGI__IS_CHARTYPE(*s, ct_start_symbol)) PUGI__THROW_ERROR(status_bad_attribute, s); + } + else PUGI__THROW_ERROR(status_bad_attribute, s); + } + else PUGI__THROW_ERROR(status_bad_attribute, s); + } + else if (*s == '/') + { + ++s; + + if (*s == '>') + { + PUGI__POPNODE(); + s++; + break; + } + else if (*s == 0 && endch == '>') + { + PUGI__POPNODE(); + break; + } + else PUGI__THROW_ERROR(status_bad_start_element, s); + } + else if (*s == '>') + { + ++s; + + break; + } + else if (*s == 0 && endch == '>') + { + break; + } + else PUGI__THROW_ERROR(status_bad_start_element, s); + } + + // !!! + } + else if (ch == '/') // '<#.../' + { + if (!PUGI__ENDSWITH(*s, '>')) PUGI__THROW_ERROR(status_bad_start_element, s); + + PUGI__POPNODE(); // Pop. + + s += (*s == '>'); + } + else if (ch == 0) + { + // we stepped over null terminator, backtrack & handle closing tag + --s; + + if (endch != '>') PUGI__THROW_ERROR(status_bad_start_element, s); + } + else PUGI__THROW_ERROR(status_bad_start_element, s); + } + else if (*s == '/') + { + ++s; + + mark = s; + + char_t* name = cursor->name; + if (!name) PUGI__THROW_ERROR(status_end_element_mismatch, mark); + + while (PUGI__IS_CHARTYPE(*s, ct_symbol)) + { + if (*s++ != *name++) PUGI__THROW_ERROR(status_end_element_mismatch, mark); + } + + if (*name) + { + if (*s == 0 && name[0] == endch && name[1] == 0) PUGI__THROW_ERROR(status_bad_end_element, s); + else PUGI__THROW_ERROR(status_end_element_mismatch, mark); + } + + PUGI__POPNODE(); // Pop. + + PUGI__SKIPWS(); + + if (*s == 0) + { + if (endch != '>') PUGI__THROW_ERROR(status_bad_end_element, s); + } + else + { + if (*s != '>') PUGI__THROW_ERROR(status_bad_end_element, s); + ++s; + } + } + else if (*s == '?') // 'first_child) continue; + } + } + + if (!PUGI__OPTSET(parse_trim_pcdata)) + s = mark; + + if (cursor->parent || PUGI__OPTSET(parse_fragment)) + { + if (PUGI__OPTSET(parse_embed_pcdata) && cursor->parent && !cursor->first_child && !cursor->value) + { + cursor->value = s; // Save the offset. + } + else + { + PUGI__PUSHNODE(node_pcdata); // Append a new node on the tree. + + cursor->value = s; // Save the offset. + + PUGI__POPNODE(); // Pop since this is a standalone. + } + + s = strconv_pcdata(s); + + if (!*s) break; + } + else + { + PUGI__SCANFOR(*s == '<'); // '...<' + if (!*s) break; + + ++s; + } + + // We're after '<' + goto LOC_TAG; + } + } + + // check that last tag is closed + if (cursor != root) PUGI__THROW_ERROR(status_end_element_mismatch, s); + + return s; + } + + #ifdef PUGIXML_WCHAR_MODE + static char_t* parse_skip_bom(char_t* s) + { + unsigned int bom = 0xfeff; + return (s[0] == static_cast(bom)) ? s + 1 : s; + } + #else + static char_t* parse_skip_bom(char_t* s) + { + return (s[0] == '\xef' && s[1] == '\xbb' && s[2] == '\xbf') ? s + 3 : s; + } + #endif + + static bool has_element_node_siblings(xml_node_struct* node) + { + while (node) + { + if (PUGI__NODETYPE(node) == node_element) return true; + + node = node->next_sibling; + } + + return false; + } + + static xml_parse_result parse(char_t* buffer, size_t length, xml_document_struct* xmldoc, xml_node_struct* root, unsigned int optmsk) + { + // early-out for empty documents + if (length == 0) + return make_parse_result(PUGI__OPTSET(parse_fragment) ? status_ok : status_no_document_element); + + // get last child of the root before parsing + xml_node_struct* last_root_child = root->first_child ? root->first_child->prev_sibling_c + 0 : 0; + + // create parser on stack + xml_parser parser(static_cast(xmldoc)); + + // save last character and make buffer zero-terminated (speeds up parsing) + char_t endch = buffer[length - 1]; + buffer[length - 1] = 0; + + // skip BOM to make sure it does not end up as part of parse output + char_t* buffer_data = parse_skip_bom(buffer); + + // perform actual parsing + parser.parse_tree(buffer_data, root, optmsk, endch); + + xml_parse_result result = make_parse_result(parser.error_status, parser.error_offset ? parser.error_offset - buffer : 0); + assert(result.offset >= 0 && static_cast(result.offset) <= length); + + if (result) + { + // since we removed last character, we have to handle the only possible false positive (stray <) + if (endch == '<') + return make_parse_result(status_unrecognized_tag, length - 1); + + // check if there are any element nodes parsed + xml_node_struct* first_root_child_parsed = last_root_child ? last_root_child->next_sibling + 0 : root->first_child+ 0; + + if (!PUGI__OPTSET(parse_fragment) && !has_element_node_siblings(first_root_child_parsed)) + return make_parse_result(status_no_document_element, length - 1); + } + else + { + // roll back offset if it occurs on a null terminator in the source buffer + if (result.offset > 0 && static_cast(result.offset) == length - 1 && endch == 0) + result.offset--; + } + + return result; + } + }; + + // Output facilities + PUGI__FN xml_encoding get_write_native_encoding() + { + #ifdef PUGIXML_WCHAR_MODE + return get_wchar_encoding(); + #else + return encoding_utf8; + #endif + } + + PUGI__FN xml_encoding get_write_encoding(xml_encoding encoding) + { + // replace wchar encoding with utf implementation + if (encoding == encoding_wchar) return get_wchar_encoding(); + + // replace utf16 encoding with utf16 with specific endianness + if (encoding == encoding_utf16) return is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + // replace utf32 encoding with utf32 with specific endianness + if (encoding == encoding_utf32) return is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + // only do autodetection if no explicit encoding is requested + if (encoding != encoding_auto) return encoding; + + // assume utf8 encoding + return encoding_utf8; + } + + template PUGI__FN size_t convert_buffer_output_generic(typename T::value_type dest, const char_t* data, size_t length, D, T) + { + PUGI__STATIC_ASSERT(sizeof(char_t) == sizeof(typename D::type)); + + typename T::value_type end = D::process(reinterpret_cast(data), length, dest, T()); + + return static_cast(end - dest) * sizeof(*dest); + } + + template PUGI__FN size_t convert_buffer_output_generic(typename T::value_type dest, const char_t* data, size_t length, D, T, bool opt_swap) + { + PUGI__STATIC_ASSERT(sizeof(char_t) == sizeof(typename D::type)); + + typename T::value_type end = D::process(reinterpret_cast(data), length, dest, T()); + + if (opt_swap) + { + for (typename T::value_type i = dest; i != end; ++i) + *i = endian_swap(*i); + } + + return static_cast(end - dest) * sizeof(*dest); + } + +#ifdef PUGIXML_WCHAR_MODE + PUGI__FN size_t get_valid_length(const char_t* data, size_t length) + { + if (length < 1) return 0; + + // discard last character if it's the lead of a surrogate pair + return (sizeof(wchar_t) == 2 && static_cast(static_cast(data[length - 1]) - 0xD800) < 0x400) ? length - 1 : length; + } + + PUGI__FN size_t convert_buffer_output(char_t* r_char, uint8_t* r_u8, uint16_t* r_u16, uint32_t* r_u32, const char_t* data, size_t length, xml_encoding encoding) + { + // only endian-swapping is required + if (need_endian_swap_utf(encoding, get_wchar_encoding())) + { + convert_wchar_endian_swap(r_char, data, length); + + return length * sizeof(char_t); + } + + // convert to utf8 + if (encoding == encoding_utf8) + return convert_buffer_output_generic(r_u8, data, length, wchar_decoder(), utf8_writer()); + + // convert to utf16 + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + return convert_buffer_output_generic(r_u16, data, length, wchar_decoder(), utf16_writer(), native_encoding != encoding); + } + + // convert to utf32 + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + return convert_buffer_output_generic(r_u32, data, length, wchar_decoder(), utf32_writer(), native_encoding != encoding); + } + + // convert to latin1 + if (encoding == encoding_latin1) + return convert_buffer_output_generic(r_u8, data, length, wchar_decoder(), latin1_writer()); + + assert(false && "Invalid encoding"); // unreachable + return 0; + } +#else + PUGI__FN size_t get_valid_length(const char_t* data, size_t length) + { + if (length < 5) return 0; + + for (size_t i = 1; i <= 4; ++i) + { + uint8_t ch = static_cast(data[length - i]); + + // either a standalone character or a leading one + if ((ch & 0xc0) != 0x80) return length - i; + } + + // there are four non-leading characters at the end, sequence tail is broken so might as well process the whole chunk + return length; + } + + PUGI__FN size_t convert_buffer_output(char_t* /* r_char */, uint8_t* r_u8, uint16_t* r_u16, uint32_t* r_u32, const char_t* data, size_t length, xml_encoding encoding) + { + if (encoding == encoding_utf16_be || encoding == encoding_utf16_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf16_le : encoding_utf16_be; + + return convert_buffer_output_generic(r_u16, data, length, utf8_decoder(), utf16_writer(), native_encoding != encoding); + } + + if (encoding == encoding_utf32_be || encoding == encoding_utf32_le) + { + xml_encoding native_encoding = is_little_endian() ? encoding_utf32_le : encoding_utf32_be; + + return convert_buffer_output_generic(r_u32, data, length, utf8_decoder(), utf32_writer(), native_encoding != encoding); + } + + if (encoding == encoding_latin1) + return convert_buffer_output_generic(r_u8, data, length, utf8_decoder(), latin1_writer()); + + assert(false && "Invalid encoding"); // unreachable + return 0; + } +#endif + + class xml_buffered_writer + { + xml_buffered_writer(const xml_buffered_writer&); + xml_buffered_writer& operator=(const xml_buffered_writer&); + + public: + xml_buffered_writer(xml_writer& writer_, xml_encoding user_encoding): writer(writer_), bufsize(0), encoding(get_write_encoding(user_encoding)) + { + PUGI__STATIC_ASSERT(bufcapacity >= 8); + } + + size_t flush() + { + flush(buffer, bufsize); + bufsize = 0; + return 0; + } + + void flush(const char_t* data, size_t size) + { + if (size == 0) return; + + // fast path, just write data + if (encoding == get_write_native_encoding()) + writer.write(data, size * sizeof(char_t)); + else + { + // convert chunk + size_t result = convert_buffer_output(scratch.data_char, scratch.data_u8, scratch.data_u16, scratch.data_u32, data, size, encoding); + assert(result <= sizeof(scratch)); + + // write data + writer.write(scratch.data_u8, result); + } + } + + void write_direct(const char_t* data, size_t length) + { + // flush the remaining buffer contents + flush(); + + // handle large chunks + if (length > bufcapacity) + { + if (encoding == get_write_native_encoding()) + { + // fast path, can just write data chunk + writer.write(data, length * sizeof(char_t)); + return; + } + + // need to convert in suitable chunks + while (length > bufcapacity) + { + // get chunk size by selecting such number of characters that are guaranteed to fit into scratch buffer + // and form a complete codepoint sequence (i.e. discard start of last codepoint if necessary) + size_t chunk_size = get_valid_length(data, bufcapacity); + assert(chunk_size); + + // convert chunk and write + flush(data, chunk_size); + + // iterate + data += chunk_size; + length -= chunk_size; + } + + // small tail is copied below + bufsize = 0; + } + + memcpy(buffer + bufsize, data, length * sizeof(char_t)); + bufsize += length; + } + + void write_buffer(const char_t* data, size_t length) + { + size_t offset = bufsize; + + if (offset + length <= bufcapacity) + { + memcpy(buffer + offset, data, length * sizeof(char_t)); + bufsize = offset + length; + } + else + { + write_direct(data, length); + } + } + + void write_string(const char_t* data) + { + // write the part of the string that fits in the buffer + size_t offset = bufsize; + + while (*data && offset < bufcapacity) + buffer[offset++] = *data++; + + // write the rest + if (offset < bufcapacity) + { + bufsize = offset; + } + else + { + // backtrack a bit if we have split the codepoint + size_t length = offset - bufsize; + size_t extra = length - get_valid_length(data - length, length); + + bufsize = offset - extra; + + write_direct(data - extra, strlength(data) + extra); + } + } + + void write(char_t d0) + { + size_t offset = bufsize; + if (offset > bufcapacity - 1) offset = flush(); + + buffer[offset + 0] = d0; + bufsize = offset + 1; + } + + void write(char_t d0, char_t d1) + { + size_t offset = bufsize; + if (offset > bufcapacity - 2) offset = flush(); + + buffer[offset + 0] = d0; + buffer[offset + 1] = d1; + bufsize = offset + 2; + } + + void write(char_t d0, char_t d1, char_t d2) + { + size_t offset = bufsize; + if (offset > bufcapacity - 3) offset = flush(); + + buffer[offset + 0] = d0; + buffer[offset + 1] = d1; + buffer[offset + 2] = d2; + bufsize = offset + 3; + } + + void write(char_t d0, char_t d1, char_t d2, char_t d3) + { + size_t offset = bufsize; + if (offset > bufcapacity - 4) offset = flush(); + + buffer[offset + 0] = d0; + buffer[offset + 1] = d1; + buffer[offset + 2] = d2; + buffer[offset + 3] = d3; + bufsize = offset + 4; + } + + void write(char_t d0, char_t d1, char_t d2, char_t d3, char_t d4) + { + size_t offset = bufsize; + if (offset > bufcapacity - 5) offset = flush(); + + buffer[offset + 0] = d0; + buffer[offset + 1] = d1; + buffer[offset + 2] = d2; + buffer[offset + 3] = d3; + buffer[offset + 4] = d4; + bufsize = offset + 5; + } + + void write(char_t d0, char_t d1, char_t d2, char_t d3, char_t d4, char_t d5) + { + size_t offset = bufsize; + if (offset > bufcapacity - 6) offset = flush(); + + buffer[offset + 0] = d0; + buffer[offset + 1] = d1; + buffer[offset + 2] = d2; + buffer[offset + 3] = d3; + buffer[offset + 4] = d4; + buffer[offset + 5] = d5; + bufsize = offset + 6; + } + + // utf8 maximum expansion: x4 (-> utf32) + // utf16 maximum expansion: x2 (-> utf32) + // utf32 maximum expansion: x1 + enum + { + bufcapacitybytes = + #ifdef PUGIXML_MEMORY_OUTPUT_STACK + PUGIXML_MEMORY_OUTPUT_STACK + #else + 10240 + #endif + , + bufcapacity = bufcapacitybytes / (sizeof(char_t) + 4) + }; + + char_t buffer[bufcapacity]; + + union + { + uint8_t data_u8[4 * bufcapacity]; + uint16_t data_u16[2 * bufcapacity]; + uint32_t data_u32[bufcapacity]; + char_t data_char[bufcapacity]; + } scratch; + + xml_writer& writer; + size_t bufsize; + xml_encoding encoding; + }; + + PUGI__FN void text_output_escaped(xml_buffered_writer& writer, const char_t* s, chartypex_t type, unsigned int flags) + { + while (*s) + { + const char_t* prev = s; + + // While *s is a usual symbol + PUGI__SCANWHILE_UNROLL(!PUGI__IS_CHARTYPEX(ss, type)); + + writer.write_buffer(prev, static_cast(s - prev)); + + switch (*s) + { + case 0: break; + case '&': + writer.write('&', 'a', 'm', 'p', ';'); + ++s; + break; + case '<': + writer.write('&', 'l', 't', ';'); + ++s; + break; + case '>': + writer.write('&', 'g', 't', ';'); + ++s; + break; + case '"': + if (flags & format_attribute_single_quote) + writer.write('"'); + else + writer.write('&', 'q', 'u', 'o', 't', ';'); + ++s; + break; + case '\'': + if (flags & format_attribute_single_quote) + writer.write('&', 'a', 'p', 'o', 's', ';'); + else + writer.write('\''); + ++s; + break; + default: // s is not a usual symbol + { + unsigned int ch = static_cast(*s++); + assert(ch < 32); + + if (!(flags & format_skip_control_chars)) + writer.write('&', '#', static_cast((ch / 10) + '0'), static_cast((ch % 10) + '0'), ';'); + } + } + } + } + + PUGI__FN void text_output(xml_buffered_writer& writer, const char_t* s, chartypex_t type, unsigned int flags) + { + if (flags & format_no_escapes) + writer.write_string(s); + else + text_output_escaped(writer, s, type, flags); + } + + PUGI__FN void text_output_cdata(xml_buffered_writer& writer, const char_t* s) + { + do + { + writer.write('<', '!', '[', 'C', 'D'); + writer.write('A', 'T', 'A', '['); + + const char_t* prev = s; + + // look for ]]> sequence - we can't output it as is since it terminates CDATA + while (*s && !(s[0] == ']' && s[1] == ']' && s[2] == '>')) ++s; + + // skip ]] if we stopped at ]]>, > will go to the next CDATA section + if (*s) s += 2; + + writer.write_buffer(prev, static_cast(s - prev)); + + writer.write(']', ']', '>'); + } + while (*s); + } + + PUGI__FN void text_output_indent(xml_buffered_writer& writer, const char_t* indent, size_t indent_length, unsigned int depth) + { + switch (indent_length) + { + case 1: + { + for (unsigned int i = 0; i < depth; ++i) + writer.write(indent[0]); + break; + } + + case 2: + { + for (unsigned int i = 0; i < depth; ++i) + writer.write(indent[0], indent[1]); + break; + } + + case 3: + { + for (unsigned int i = 0; i < depth; ++i) + writer.write(indent[0], indent[1], indent[2]); + break; + } + + case 4: + { + for (unsigned int i = 0; i < depth; ++i) + writer.write(indent[0], indent[1], indent[2], indent[3]); + break; + } + + default: + { + for (unsigned int i = 0; i < depth; ++i) + writer.write_buffer(indent, indent_length); + } + } + } + + PUGI__FN void node_output_comment(xml_buffered_writer& writer, const char_t* s) + { + writer.write('<', '!', '-', '-'); + + while (*s) + { + const char_t* prev = s; + + // look for -\0 or -- sequence - we can't output it since -- is illegal in comment body + while (*s && !(s[0] == '-' && (s[1] == '-' || s[1] == 0))) ++s; + + writer.write_buffer(prev, static_cast(s - prev)); + + if (*s) + { + assert(*s == '-'); + + writer.write('-', ' '); + ++s; + } + } + + writer.write('-', '-', '>'); + } + + PUGI__FN void node_output_pi_value(xml_buffered_writer& writer, const char_t* s) + { + while (*s) + { + const char_t* prev = s; + + // look for ?> sequence - we can't output it since ?> terminates PI + while (*s && !(s[0] == '?' && s[1] == '>')) ++s; + + writer.write_buffer(prev, static_cast(s - prev)); + + if (*s) + { + assert(s[0] == '?' && s[1] == '>'); + + writer.write('?', ' ', '>'); + s += 2; + } + } + } + + PUGI__FN void node_output_attributes(xml_buffered_writer& writer, xml_node_struct* node, const char_t* indent, size_t indent_length, unsigned int flags, unsigned int depth) + { + const char_t* default_name = PUGIXML_TEXT(":anonymous"); + const char_t enquotation_char = (flags & format_attribute_single_quote) ? '\'' : '"'; + + for (xml_attribute_struct* a = node->first_attribute; a; a = a->next_attribute) + { + if ((flags & (format_indent_attributes | format_raw)) == format_indent_attributes) + { + writer.write('\n'); + + text_output_indent(writer, indent, indent_length, depth + 1); + } + else + { + writer.write(' '); + } + + writer.write_string(a->name ? a->name + 0 : default_name); + writer.write('=', enquotation_char); + + if (a->value) + text_output(writer, a->value, ctx_special_attr, flags); + + writer.write(enquotation_char); + } + } + + PUGI__FN bool node_output_start(xml_buffered_writer& writer, xml_node_struct* node, const char_t* indent, size_t indent_length, unsigned int flags, unsigned int depth) + { + const char_t* default_name = PUGIXML_TEXT(":anonymous"); + const char_t* name = node->name ? node->name + 0 : default_name; + + writer.write('<'); + writer.write_string(name); + + if (node->first_attribute) + node_output_attributes(writer, node, indent, indent_length, flags, depth); + + // element nodes can have value if parse_embed_pcdata was used + if (!node->value) + { + if (!node->first_child) + { + if (flags & format_no_empty_element_tags) + { + writer.write('>', '<', '/'); + writer.write_string(name); + writer.write('>'); + + return false; + } + else + { + if ((flags & format_raw) == 0) + writer.write(' '); + + writer.write('/', '>'); + + return false; + } + } + else + { + writer.write('>'); + + return true; + } + } + else + { + writer.write('>'); + + text_output(writer, node->value, ctx_special_pcdata, flags); + + if (!node->first_child) + { + writer.write('<', '/'); + writer.write_string(name); + writer.write('>'); + + return false; + } + else + { + return true; + } + } + } + + PUGI__FN void node_output_end(xml_buffered_writer& writer, xml_node_struct* node) + { + const char_t* default_name = PUGIXML_TEXT(":anonymous"); + const char_t* name = node->name ? node->name + 0 : default_name; + + writer.write('<', '/'); + writer.write_string(name); + writer.write('>'); + } + + PUGI__FN void node_output_simple(xml_buffered_writer& writer, xml_node_struct* node, unsigned int flags) + { + const char_t* default_name = PUGIXML_TEXT(":anonymous"); + + switch (PUGI__NODETYPE(node)) + { + case node_pcdata: + text_output(writer, node->value ? node->value + 0 : PUGIXML_TEXT(""), ctx_special_pcdata, flags); + break; + + case node_cdata: + text_output_cdata(writer, node->value ? node->value + 0 : PUGIXML_TEXT("")); + break; + + case node_comment: + node_output_comment(writer, node->value ? node->value + 0 : PUGIXML_TEXT("")); + break; + + case node_pi: + writer.write('<', '?'); + writer.write_string(node->name ? node->name + 0 : default_name); + + if (node->value) + { + writer.write(' '); + node_output_pi_value(writer, node->value); + } + + writer.write('?', '>'); + break; + + case node_declaration: + writer.write('<', '?'); + writer.write_string(node->name ? node->name + 0 : default_name); + node_output_attributes(writer, node, PUGIXML_TEXT(""), 0, flags | format_raw, 0); + writer.write('?', '>'); + break; + + case node_doctype: + writer.write('<', '!', 'D', 'O', 'C'); + writer.write('T', 'Y', 'P', 'E'); + + if (node->value) + { + writer.write(' '); + writer.write_string(node->value); + } + + writer.write('>'); + break; + + default: + assert(false && "Invalid node type"); // unreachable + } + } + + enum indent_flags_t + { + indent_newline = 1, + indent_indent = 2 + }; + + PUGI__FN void node_output(xml_buffered_writer& writer, xml_node_struct* root, const char_t* indent, unsigned int flags, unsigned int depth) + { + size_t indent_length = ((flags & (format_indent | format_indent_attributes)) && (flags & format_raw) == 0) ? strlength(indent) : 0; + unsigned int indent_flags = indent_indent; + + xml_node_struct* node = root; + + do + { + assert(node); + + // begin writing current node + if (PUGI__NODETYPE(node) == node_pcdata || PUGI__NODETYPE(node) == node_cdata) + { + node_output_simple(writer, node, flags); + + indent_flags = 0; + } + else + { + if ((indent_flags & indent_newline) && (flags & format_raw) == 0) + writer.write('\n'); + + if ((indent_flags & indent_indent) && indent_length) + text_output_indent(writer, indent, indent_length, depth); + + if (PUGI__NODETYPE(node) == node_element) + { + indent_flags = indent_newline | indent_indent; + + if (node_output_start(writer, node, indent, indent_length, flags, depth)) + { + // element nodes can have value if parse_embed_pcdata was used + if (node->value) + indent_flags = 0; + + node = node->first_child; + depth++; + continue; + } + } + else if (PUGI__NODETYPE(node) == node_document) + { + indent_flags = indent_indent; + + if (node->first_child) + { + node = node->first_child; + continue; + } + } + else + { + node_output_simple(writer, node, flags); + + indent_flags = indent_newline | indent_indent; + } + } + + // continue to the next node + while (node != root) + { + if (node->next_sibling) + { + node = node->next_sibling; + break; + } + + node = node->parent; + + // write closing node + if (PUGI__NODETYPE(node) == node_element) + { + depth--; + + if ((indent_flags & indent_newline) && (flags & format_raw) == 0) + writer.write('\n'); + + if ((indent_flags & indent_indent) && indent_length) + text_output_indent(writer, indent, indent_length, depth); + + node_output_end(writer, node); + + indent_flags = indent_newline | indent_indent; + } + } + } + while (node != root); + + if ((indent_flags & indent_newline) && (flags & format_raw) == 0) + writer.write('\n'); + } + + PUGI__FN bool has_declaration(xml_node_struct* node) + { + for (xml_node_struct* child = node->first_child; child; child = child->next_sibling) + { + xml_node_type type = PUGI__NODETYPE(child); + + if (type == node_declaration) return true; + if (type == node_element) return false; + } + + return false; + } + + PUGI__FN bool is_attribute_of(xml_attribute_struct* attr, xml_node_struct* node) + { + for (xml_attribute_struct* a = node->first_attribute; a; a = a->next_attribute) + if (a == attr) + return true; + + return false; + } + + PUGI__FN bool allow_insert_attribute(xml_node_type parent) + { + return parent == node_element || parent == node_declaration; + } + + PUGI__FN bool allow_insert_child(xml_node_type parent, xml_node_type child) + { + if (parent != node_document && parent != node_element) return false; + if (child == node_document || child == node_null) return false; + if (parent != node_document && (child == node_declaration || child == node_doctype)) return false; + + return true; + } + + PUGI__FN bool allow_move(xml_node parent, xml_node child) + { + // check that child can be a child of parent + if (!allow_insert_child(parent.type(), child.type())) + return false; + + // check that node is not moved between documents + if (parent.root() != child.root()) + return false; + + // check that new parent is not in the child subtree + xml_node cur = parent; + + while (cur) + { + if (cur == child) + return false; + + cur = cur.parent(); + } + + return true; + } + + template + PUGI__FN void node_copy_string(String& dest, Header& header, uintptr_t header_mask, char_t* source, Header& source_header, xml_allocator* alloc) + { + assert(!dest && (header & header_mask) == 0); + + if (source) + { + if (alloc && (source_header & header_mask) == 0) + { + dest = source; + + // since strcpy_insitu can reuse document buffer memory we need to mark both source and dest as shared + header |= xml_memory_page_contents_shared_mask; + source_header |= xml_memory_page_contents_shared_mask; + } + else + strcpy_insitu(dest, header, header_mask, source, strlength(source)); + } + } + + PUGI__FN void node_copy_contents(xml_node_struct* dn, xml_node_struct* sn, xml_allocator* shared_alloc) + { + node_copy_string(dn->name, dn->header, xml_memory_page_name_allocated_mask, sn->name, sn->header, shared_alloc); + node_copy_string(dn->value, dn->header, xml_memory_page_value_allocated_mask, sn->value, sn->header, shared_alloc); + + for (xml_attribute_struct* sa = sn->first_attribute; sa; sa = sa->next_attribute) + { + xml_attribute_struct* da = append_new_attribute(dn, get_allocator(dn)); + + if (da) + { + node_copy_string(da->name, da->header, xml_memory_page_name_allocated_mask, sa->name, sa->header, shared_alloc); + node_copy_string(da->value, da->header, xml_memory_page_value_allocated_mask, sa->value, sa->header, shared_alloc); + } + } + } + + PUGI__FN void node_copy_tree(xml_node_struct* dn, xml_node_struct* sn) + { + xml_allocator& alloc = get_allocator(dn); + xml_allocator* shared_alloc = (&alloc == &get_allocator(sn)) ? &alloc : 0; + + node_copy_contents(dn, sn, shared_alloc); + + xml_node_struct* dit = dn; + xml_node_struct* sit = sn->first_child; + + while (sit && sit != sn) + { + // loop invariant: dit is inside the subtree rooted at dn + assert(dit); + + // when a tree is copied into one of the descendants, we need to skip that subtree to avoid an infinite loop + if (sit != dn) + { + xml_node_struct* copy = append_new_node(dit, alloc, PUGI__NODETYPE(sit)); + + if (copy) + { + node_copy_contents(copy, sit, shared_alloc); + + if (sit->first_child) + { + dit = copy; + sit = sit->first_child; + continue; + } + } + } + + // continue to the next node + do + { + if (sit->next_sibling) + { + sit = sit->next_sibling; + break; + } + + sit = sit->parent; + dit = dit->parent; + + // loop invariant: dit is inside the subtree rooted at dn while sit is inside sn + assert(sit == sn || dit); + } + while (sit != sn); + } + + assert(!sit || dit == dn->parent); + } + + PUGI__FN void node_copy_attribute(xml_attribute_struct* da, xml_attribute_struct* sa) + { + xml_allocator& alloc = get_allocator(da); + xml_allocator* shared_alloc = (&alloc == &get_allocator(sa)) ? &alloc : 0; + + node_copy_string(da->name, da->header, xml_memory_page_name_allocated_mask, sa->name, sa->header, shared_alloc); + node_copy_string(da->value, da->header, xml_memory_page_value_allocated_mask, sa->value, sa->header, shared_alloc); + } + + inline bool is_text_node(xml_node_struct* node) + { + xml_node_type type = PUGI__NODETYPE(node); + + return type == node_pcdata || type == node_cdata; + } + + // get value with conversion functions + template PUGI__FN PUGI__UNSIGNED_OVERFLOW U string_to_integer(const char_t* value, U minv, U maxv) + { + U result = 0; + const char_t* s = value; + + while (PUGI__IS_CHARTYPE(*s, ct_space)) + s++; + + bool negative = (*s == '-'); + + s += (*s == '+' || *s == '-'); + + bool overflow = false; + + if (s[0] == '0' && (s[1] | ' ') == 'x') + { + s += 2; + + // since overflow detection relies on length of the sequence skip leading zeros + while (*s == '0') + s++; + + const char_t* start = s; + + for (;;) + { + if (static_cast(*s - '0') < 10) + result = result * 16 + (*s - '0'); + else if (static_cast((*s | ' ') - 'a') < 6) + result = result * 16 + ((*s | ' ') - 'a' + 10); + else + break; + + s++; + } + + size_t digits = static_cast(s - start); + + overflow = digits > sizeof(U) * 2; + } + else + { + // since overflow detection relies on length of the sequence skip leading zeros + while (*s == '0') + s++; + + const char_t* start = s; + + for (;;) + { + if (static_cast(*s - '0') < 10) + result = result * 10 + (*s - '0'); + else + break; + + s++; + } + + size_t digits = static_cast(s - start); + + PUGI__STATIC_ASSERT(sizeof(U) == 8 || sizeof(U) == 4 || sizeof(U) == 2); + + const size_t max_digits10 = sizeof(U) == 8 ? 20 : sizeof(U) == 4 ? 10 : 5; + const char_t max_lead = sizeof(U) == 8 ? '1' : sizeof(U) == 4 ? '4' : '6'; + const size_t high_bit = sizeof(U) * 8 - 1; + + overflow = digits >= max_digits10 && !(digits == max_digits10 && (*start < max_lead || (*start == max_lead && result >> high_bit))); + } + + if (negative) + { + // Workaround for crayc++ CC-3059: Expected no overflow in routine. + #ifdef _CRAYC + return (overflow || result > ~minv + 1) ? minv : ~result + 1; + #else + return (overflow || result > 0 - minv) ? minv : 0 - result; + #endif + } + else + return (overflow || result > maxv) ? maxv : result; + } + + PUGI__FN int get_value_int(const char_t* value) + { + return string_to_integer(value, static_cast(INT_MIN), INT_MAX); + } + + PUGI__FN unsigned int get_value_uint(const char_t* value) + { + return string_to_integer(value, 0, UINT_MAX); + } + + PUGI__FN double get_value_double(const char_t* value) + { + #ifdef PUGIXML_WCHAR_MODE + return wcstod(value, 0); + #else + return strtod(value, 0); + #endif + } + + PUGI__FN float get_value_float(const char_t* value) + { + #ifdef PUGIXML_WCHAR_MODE + return static_cast(wcstod(value, 0)); + #else + return static_cast(strtod(value, 0)); + #endif + } + + PUGI__FN bool get_value_bool(const char_t* value) + { + // only look at first char + char_t first = *value; + + // 1*, t* (true), T* (True), y* (yes), Y* (YES) + return (first == '1' || first == 't' || first == 'T' || first == 'y' || first == 'Y'); + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN long long get_value_llong(const char_t* value) + { + return string_to_integer(value, static_cast(LLONG_MIN), LLONG_MAX); + } + + PUGI__FN unsigned long long get_value_ullong(const char_t* value) + { + return string_to_integer(value, 0, ULLONG_MAX); + } +#endif + + template PUGI__FN PUGI__UNSIGNED_OVERFLOW char_t* integer_to_string(char_t* begin, char_t* end, U value, bool negative) + { + char_t* result = end - 1; + U rest = negative ? 0 - value : value; + + do + { + *result-- = static_cast('0' + (rest % 10)); + rest /= 10; + } + while (rest); + + assert(result >= begin); + (void)begin; + + *result = '-'; + + return result + !negative; + } + + // set value with conversion functions + template + PUGI__FN bool set_value_ascii(String& dest, Header& header, uintptr_t header_mask, char* buf) + { + #ifdef PUGIXML_WCHAR_MODE + char_t wbuf[128]; + assert(strlen(buf) < sizeof(wbuf) / sizeof(wbuf[0])); + + size_t offset = 0; + for (; buf[offset]; ++offset) wbuf[offset] = buf[offset]; + + return strcpy_insitu(dest, header, header_mask, wbuf, offset); + #else + return strcpy_insitu(dest, header, header_mask, buf, strlen(buf)); + #endif + } + + template + PUGI__FN bool set_value_integer(String& dest, Header& header, uintptr_t header_mask, U value, bool negative) + { + char_t buf[64]; + char_t* end = buf + sizeof(buf) / sizeof(buf[0]); + char_t* begin = integer_to_string(buf, end, value, negative); + + return strcpy_insitu(dest, header, header_mask, begin, end - begin); + } + + template + PUGI__FN bool set_value_convert(String& dest, Header& header, uintptr_t header_mask, float value, int precision) + { + char buf[128]; + PUGI__SNPRINTF(buf, "%.*g", precision, double(value)); + + return set_value_ascii(dest, header, header_mask, buf); + } + + template + PUGI__FN bool set_value_convert(String& dest, Header& header, uintptr_t header_mask, double value, int precision) + { + char buf[128]; + PUGI__SNPRINTF(buf, "%.*g", precision, value); + + return set_value_ascii(dest, header, header_mask, buf); + } + + template + PUGI__FN bool set_value_bool(String& dest, Header& header, uintptr_t header_mask, bool value) + { + return strcpy_insitu(dest, header, header_mask, value ? PUGIXML_TEXT("true") : PUGIXML_TEXT("false"), value ? 4 : 5); + } + + PUGI__FN xml_parse_result load_buffer_impl(xml_document_struct* doc, xml_node_struct* root, void* contents, size_t size, unsigned int options, xml_encoding encoding, bool is_mutable, bool own, char_t** out_buffer) + { + // check input buffer + if (!contents && size) return make_parse_result(status_io_error); + + // get actual encoding + xml_encoding buffer_encoding = impl::get_buffer_encoding(encoding, contents, size); + + // get private buffer + char_t* buffer = 0; + size_t length = 0; + + // coverity[var_deref_model] + if (!impl::convert_buffer(buffer, length, buffer_encoding, contents, size, is_mutable)) return impl::make_parse_result(status_out_of_memory); + + // delete original buffer if we performed a conversion + if (own && buffer != contents && contents) impl::xml_memory::deallocate(contents); + + // grab onto buffer if it's our buffer, user is responsible for deallocating contents himself + if (own || buffer != contents) *out_buffer = buffer; + + // store buffer for offset_debug + doc->buffer = buffer; + + // parse + xml_parse_result res = impl::xml_parser::parse(buffer, length, doc, root, options); + + // remember encoding + res.encoding = buffer_encoding; + + return res; + } + + // we need to get length of entire file to load it in memory; the only (relatively) sane way to do it is via seek/tell trick + PUGI__FN xml_parse_status get_file_size(FILE* file, size_t& out_result) + { + #if defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 && !defined(_WIN32_WCE) + // there are 64-bit versions of fseek/ftell, let's use them + typedef __int64 length_type; + + _fseeki64(file, 0, SEEK_END); + length_type length = _ftelli64(file); + _fseeki64(file, 0, SEEK_SET); + #elif defined(__MINGW32__) && !defined(__NO_MINGW_LFS) && (!defined(__STRICT_ANSI__) || defined(__MINGW64_VERSION_MAJOR)) + // there are 64-bit versions of fseek/ftell, let's use them + typedef off64_t length_type; + + fseeko64(file, 0, SEEK_END); + length_type length = ftello64(file); + fseeko64(file, 0, SEEK_SET); + #else + // if this is a 32-bit OS, long is enough; if this is a unix system, long is 64-bit, which is enough; otherwise we can't do anything anyway. + typedef long length_type; + + fseek(file, 0, SEEK_END); + length_type length = ftell(file); + fseek(file, 0, SEEK_SET); + #endif + + // check for I/O errors + if (length < 0) return status_io_error; + + // check for overflow + size_t result = static_cast(length); + + if (static_cast(result) != length) return status_out_of_memory; + + // finalize + out_result = result; + + return status_ok; + } + + // This function assumes that buffer has extra sizeof(char_t) writable bytes after size + PUGI__FN size_t zero_terminate_buffer(void* buffer, size_t size, xml_encoding encoding) + { + // We only need to zero-terminate if encoding conversion does not do it for us + #ifdef PUGIXML_WCHAR_MODE + xml_encoding wchar_encoding = get_wchar_encoding(); + + if (encoding == wchar_encoding || need_endian_swap_utf(encoding, wchar_encoding)) + { + size_t length = size / sizeof(char_t); + + static_cast(buffer)[length] = 0; + return (length + 1) * sizeof(char_t); + } + #else + if (encoding == encoding_utf8) + { + static_cast(buffer)[size] = 0; + return size + 1; + } + #endif + + return size; + } + + PUGI__FN xml_parse_result load_file_impl(xml_document_struct* doc, FILE* file, unsigned int options, xml_encoding encoding, char_t** out_buffer) + { + if (!file) return make_parse_result(status_file_not_found); + + // get file size (can result in I/O errors) + size_t size = 0; + xml_parse_status size_status = get_file_size(file, size); + if (size_status != status_ok) return make_parse_result(size_status); + + size_t max_suffix_size = sizeof(char_t); + + // allocate buffer for the whole file + char* contents = static_cast(xml_memory::allocate(size + max_suffix_size)); + if (!contents) return make_parse_result(status_out_of_memory); + + // read file in memory + size_t read_size = fread(contents, 1, size, file); + + if (read_size != size) + { + xml_memory::deallocate(contents); + return make_parse_result(status_io_error); + } + + xml_encoding real_encoding = get_buffer_encoding(encoding, contents, size); + + return load_buffer_impl(doc, doc, contents, zero_terminate_buffer(contents, size, real_encoding), options, real_encoding, true, true, out_buffer); + } + + PUGI__FN void close_file(FILE* file) + { + fclose(file); + } + +#ifndef PUGIXML_NO_STL + template struct xml_stream_chunk + { + static xml_stream_chunk* create() + { + void* memory = xml_memory::allocate(sizeof(xml_stream_chunk)); + if (!memory) return 0; + + return new (memory) xml_stream_chunk(); + } + + static void destroy(xml_stream_chunk* chunk) + { + // free chunk chain + while (chunk) + { + xml_stream_chunk* next_ = chunk->next; + + xml_memory::deallocate(chunk); + + chunk = next_; + } + } + + xml_stream_chunk(): next(0), size(0) + { + } + + xml_stream_chunk* next; + size_t size; + + T data[xml_memory_page_size / sizeof(T)]; + }; + + template PUGI__FN xml_parse_status load_stream_data_noseek(std::basic_istream& stream, void** out_buffer, size_t* out_size) + { + auto_deleter > chunks(0, xml_stream_chunk::destroy); + + // read file to a chunk list + size_t total = 0; + xml_stream_chunk* last = 0; + + while (!stream.eof()) + { + // allocate new chunk + xml_stream_chunk* chunk = xml_stream_chunk::create(); + if (!chunk) return status_out_of_memory; + + // append chunk to list + if (last) last = last->next = chunk; + else chunks.data = last = chunk; + + // read data to chunk + stream.read(chunk->data, static_cast(sizeof(chunk->data) / sizeof(T))); + chunk->size = static_cast(stream.gcount()) * sizeof(T); + + // read may set failbit | eofbit in case gcount() is less than read length, so check for other I/O errors + if (stream.bad() || (!stream.eof() && stream.fail())) return status_io_error; + + // guard against huge files (chunk size is small enough to make this overflow check work) + if (total + chunk->size < total) return status_out_of_memory; + total += chunk->size; + } + + size_t max_suffix_size = sizeof(char_t); + + // copy chunk list to a contiguous buffer + char* buffer = static_cast(xml_memory::allocate(total + max_suffix_size)); + if (!buffer) return status_out_of_memory; + + char* write = buffer; + + for (xml_stream_chunk* chunk = chunks.data; chunk; chunk = chunk->next) + { + assert(write + chunk->size <= buffer + total); + memcpy(write, chunk->data, chunk->size); + write += chunk->size; + } + + assert(write == buffer + total); + + // return buffer + *out_buffer = buffer; + *out_size = total; + + return status_ok; + } + + template PUGI__FN xml_parse_status load_stream_data_seek(std::basic_istream& stream, void** out_buffer, size_t* out_size) + { + // get length of remaining data in stream + typename std::basic_istream::pos_type pos = stream.tellg(); + stream.seekg(0, std::ios::end); + std::streamoff length = stream.tellg() - pos; + stream.seekg(pos); + + if (stream.fail() || pos < 0) return status_io_error; + + // guard against huge files + size_t read_length = static_cast(length); + + if (static_cast(read_length) != length || length < 0) return status_out_of_memory; + + size_t max_suffix_size = sizeof(char_t); + + // read stream data into memory (guard against stream exceptions with buffer holder) + auto_deleter buffer(xml_memory::allocate(read_length * sizeof(T) + max_suffix_size), xml_memory::deallocate); + if (!buffer.data) return status_out_of_memory; + + stream.read(static_cast(buffer.data), static_cast(read_length)); + + // read may set failbit | eofbit in case gcount() is less than read_length (i.e. line ending conversion), so check for other I/O errors + if (stream.bad() || (!stream.eof() && stream.fail())) return status_io_error; + + // return buffer + size_t actual_length = static_cast(stream.gcount()); + assert(actual_length <= read_length); + + *out_buffer = buffer.release(); + *out_size = actual_length * sizeof(T); + + return status_ok; + } + + template PUGI__FN xml_parse_result load_stream_impl(xml_document_struct* doc, std::basic_istream& stream, unsigned int options, xml_encoding encoding, char_t** out_buffer) + { + void* buffer = 0; + size_t size = 0; + xml_parse_status status = status_ok; + + // if stream has an error bit set, bail out (otherwise tellg() can fail and we'll clear error bits) + if (stream.fail()) return make_parse_result(status_io_error); + + // load stream to memory (using seek-based implementation if possible, since it's faster and takes less memory) + if (stream.tellg() < 0) + { + stream.clear(); // clear error flags that could be set by a failing tellg + status = load_stream_data_noseek(stream, &buffer, &size); + } + else + status = load_stream_data_seek(stream, &buffer, &size); + + if (status != status_ok) return make_parse_result(status); + + xml_encoding real_encoding = get_buffer_encoding(encoding, buffer, size); + + return load_buffer_impl(doc, doc, buffer, zero_terminate_buffer(buffer, size, real_encoding), options, real_encoding, true, true, out_buffer); + } +#endif + +#if defined(PUGI__MSVC_CRT_VERSION) || defined(__BORLANDC__) || (defined(__MINGW32__) && (!defined(__STRICT_ANSI__) || defined(__MINGW64_VERSION_MAJOR))) + PUGI__FN FILE* open_file_wide(const wchar_t* path, const wchar_t* mode) + { +#if defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 + FILE* file = 0; + return _wfopen_s(&file, path, mode) == 0 ? file : 0; +#else + return _wfopen(path, mode); +#endif + } +#else + PUGI__FN char* convert_path_heap(const wchar_t* str) + { + assert(str); + + // first pass: get length in utf8 characters + size_t length = strlength_wide(str); + size_t size = as_utf8_begin(str, length); + + // allocate resulting string + char* result = static_cast(xml_memory::allocate(size + 1)); + if (!result) return 0; + + // second pass: convert to utf8 + as_utf8_end(result, size, str, length); + + // zero-terminate + result[size] = 0; + + return result; + } + + PUGI__FN FILE* open_file_wide(const wchar_t* path, const wchar_t* mode) + { + // there is no standard function to open wide paths, so our best bet is to try utf8 path + char* path_utf8 = convert_path_heap(path); + if (!path_utf8) return 0; + + // convert mode to ASCII (we mirror _wfopen interface) + char mode_ascii[4] = {0}; + for (size_t i = 0; mode[i]; ++i) mode_ascii[i] = static_cast(mode[i]); + + // try to open the utf8 path + FILE* result = fopen(path_utf8, mode_ascii); + + // free dummy buffer + xml_memory::deallocate(path_utf8); + + return result; + } +#endif + + PUGI__FN FILE* open_file(const char* path, const char* mode) + { +#if defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 + FILE* file = 0; + return fopen_s(&file, path, mode) == 0 ? file : 0; +#else + return fopen(path, mode); +#endif + } + + PUGI__FN bool save_file_impl(const xml_document& doc, FILE* file, const char_t* indent, unsigned int flags, xml_encoding encoding) + { + if (!file) return false; + + xml_writer_file writer(file); + doc.save(writer, indent, flags, encoding); + + return ferror(file) == 0; + } + + struct name_null_sentry + { + xml_node_struct* node; + char_t* name; + + name_null_sentry(xml_node_struct* node_): node(node_), name(node_->name) + { + node->name = 0; + } + + ~name_null_sentry() + { + node->name = name; + } + }; +PUGI__NS_END + +namespace pugi +{ + PUGI__FN xml_writer_file::xml_writer_file(void* file_): file(file_) + { + } + + PUGI__FN void xml_writer_file::write(const void* data, size_t size) + { + size_t result = fwrite(data, 1, size, static_cast(file)); + (void)!result; // unfortunately we can't do proper error handling here + } + +#ifndef PUGIXML_NO_STL + PUGI__FN xml_writer_stream::xml_writer_stream(std::basic_ostream >& stream): narrow_stream(&stream), wide_stream(0) + { + } + + PUGI__FN xml_writer_stream::xml_writer_stream(std::basic_ostream >& stream): narrow_stream(0), wide_stream(&stream) + { + } + + PUGI__FN void xml_writer_stream::write(const void* data, size_t size) + { + if (narrow_stream) + { + assert(!wide_stream); + narrow_stream->write(reinterpret_cast(data), static_cast(size)); + } + else + { + assert(wide_stream); + assert(size % sizeof(wchar_t) == 0); + + wide_stream->write(reinterpret_cast(data), static_cast(size / sizeof(wchar_t))); + } + } +#endif + + PUGI__FN xml_tree_walker::xml_tree_walker(): _depth(0) + { + } + + PUGI__FN xml_tree_walker::~xml_tree_walker() + { + } + + PUGI__FN int xml_tree_walker::depth() const + { + return _depth; + } + + PUGI__FN bool xml_tree_walker::begin(xml_node&) + { + return true; + } + + PUGI__FN bool xml_tree_walker::end(xml_node&) + { + return true; + } + + PUGI__FN xml_attribute::xml_attribute(): _attr(0) + { + } + + PUGI__FN xml_attribute::xml_attribute(xml_attribute_struct* attr): _attr(attr) + { + } + + PUGI__FN static void unspecified_bool_xml_attribute(xml_attribute***) + { + } + + PUGI__FN xml_attribute::operator xml_attribute::unspecified_bool_type() const + { + return _attr ? unspecified_bool_xml_attribute : 0; + } + + PUGI__FN bool xml_attribute::operator!() const + { + return !_attr; + } + + PUGI__FN bool xml_attribute::operator==(const xml_attribute& r) const + { + return (_attr == r._attr); + } + + PUGI__FN bool xml_attribute::operator!=(const xml_attribute& r) const + { + return (_attr != r._attr); + } + + PUGI__FN bool xml_attribute::operator<(const xml_attribute& r) const + { + return (_attr < r._attr); + } + + PUGI__FN bool xml_attribute::operator>(const xml_attribute& r) const + { + return (_attr > r._attr); + } + + PUGI__FN bool xml_attribute::operator<=(const xml_attribute& r) const + { + return (_attr <= r._attr); + } + + PUGI__FN bool xml_attribute::operator>=(const xml_attribute& r) const + { + return (_attr >= r._attr); + } + + PUGI__FN xml_attribute xml_attribute::next_attribute() const + { + return _attr ? xml_attribute(_attr->next_attribute) : xml_attribute(); + } + + PUGI__FN xml_attribute xml_attribute::previous_attribute() const + { + return _attr && _attr->prev_attribute_c->next_attribute ? xml_attribute(_attr->prev_attribute_c) : xml_attribute(); + } + + PUGI__FN const char_t* xml_attribute::as_string(const char_t* def) const + { + return (_attr && _attr->value) ? _attr->value + 0 : def; + } + + PUGI__FN int xml_attribute::as_int(int def) const + { + return (_attr && _attr->value) ? impl::get_value_int(_attr->value) : def; + } + + PUGI__FN unsigned int xml_attribute::as_uint(unsigned int def) const + { + return (_attr && _attr->value) ? impl::get_value_uint(_attr->value) : def; + } + + PUGI__FN double xml_attribute::as_double(double def) const + { + return (_attr && _attr->value) ? impl::get_value_double(_attr->value) : def; + } + + PUGI__FN float xml_attribute::as_float(float def) const + { + return (_attr && _attr->value) ? impl::get_value_float(_attr->value) : def; + } + + PUGI__FN bool xml_attribute::as_bool(bool def) const + { + return (_attr && _attr->value) ? impl::get_value_bool(_attr->value) : def; + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN long long xml_attribute::as_llong(long long def) const + { + return (_attr && _attr->value) ? impl::get_value_llong(_attr->value) : def; + } + + PUGI__FN unsigned long long xml_attribute::as_ullong(unsigned long long def) const + { + return (_attr && _attr->value) ? impl::get_value_ullong(_attr->value) : def; + } +#endif + + PUGI__FN bool xml_attribute::empty() const + { + return !_attr; + } + + PUGI__FN const char_t* xml_attribute::name() const + { + return (_attr && _attr->name) ? _attr->name + 0 : PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* xml_attribute::value() const + { + return (_attr && _attr->value) ? _attr->value + 0 : PUGIXML_TEXT(""); + } + + PUGI__FN size_t xml_attribute::hash_value() const + { + return static_cast(reinterpret_cast(_attr) / sizeof(xml_attribute_struct)); + } + + PUGI__FN xml_attribute_struct* xml_attribute::internal_object() const + { + return _attr; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(const char_t* rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(int rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(unsigned int rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(long rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(unsigned long rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(double rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(float rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(bool rhs) + { + set_value(rhs); + return *this; + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN xml_attribute& xml_attribute::operator=(long long rhs) + { + set_value(rhs); + return *this; + } + + PUGI__FN xml_attribute& xml_attribute::operator=(unsigned long long rhs) + { + set_value(rhs); + return *this; + } +#endif + + PUGI__FN bool xml_attribute::set_name(const char_t* rhs) + { + if (!_attr) return false; + + return impl::strcpy_insitu(_attr->name, _attr->header, impl::xml_memory_page_name_allocated_mask, rhs, impl::strlength(rhs)); + } + + PUGI__FN bool xml_attribute::set_value(const char_t* rhs) + { + if (!_attr) return false; + + return impl::strcpy_insitu(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, impl::strlength(rhs)); + } + + PUGI__FN bool xml_attribute::set_value(int rhs) + { + if (!_attr) return false; + + return impl::set_value_integer(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, rhs < 0); + } + + PUGI__FN bool xml_attribute::set_value(unsigned int rhs) + { + if (!_attr) return false; + + return impl::set_value_integer(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, false); + } + + PUGI__FN bool xml_attribute::set_value(long rhs) + { + if (!_attr) return false; + + return impl::set_value_integer(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, rhs < 0); + } + + PUGI__FN bool xml_attribute::set_value(unsigned long rhs) + { + if (!_attr) return false; + + return impl::set_value_integer(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, false); + } + + PUGI__FN bool xml_attribute::set_value(double rhs) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, default_double_precision); + } + + PUGI__FN bool xml_attribute::set_value(double rhs, int precision) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, precision); + } + + PUGI__FN bool xml_attribute::set_value(float rhs) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, default_float_precision); + } + + PUGI__FN bool xml_attribute::set_value(float rhs, int precision) + { + if (!_attr) return false; + + return impl::set_value_convert(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, precision); + } + + PUGI__FN bool xml_attribute::set_value(bool rhs) + { + if (!_attr) return false; + + return impl::set_value_bool(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs); + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN bool xml_attribute::set_value(long long rhs) + { + if (!_attr) return false; + + return impl::set_value_integer(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, rhs < 0); + } + + PUGI__FN bool xml_attribute::set_value(unsigned long long rhs) + { + if (!_attr) return false; + + return impl::set_value_integer(_attr->value, _attr->header, impl::xml_memory_page_value_allocated_mask, rhs, false); + } +#endif + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xml_attribute& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xml_attribute& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN xml_node::xml_node(): _root(0) + { + } + + PUGI__FN xml_node::xml_node(xml_node_struct* p): _root(p) + { + } + + PUGI__FN static void unspecified_bool_xml_node(xml_node***) + { + } + + PUGI__FN xml_node::operator xml_node::unspecified_bool_type() const + { + return _root ? unspecified_bool_xml_node : 0; + } + + PUGI__FN bool xml_node::operator!() const + { + return !_root; + } + + PUGI__FN xml_node::iterator xml_node::begin() const + { + return iterator(_root ? _root->first_child + 0 : 0, _root); + } + + PUGI__FN xml_node::iterator xml_node::end() const + { + return iterator(0, _root); + } + + PUGI__FN xml_node::attribute_iterator xml_node::attributes_begin() const + { + return attribute_iterator(_root ? _root->first_attribute + 0 : 0, _root); + } + + PUGI__FN xml_node::attribute_iterator xml_node::attributes_end() const + { + return attribute_iterator(0, _root); + } + + PUGI__FN xml_object_range xml_node::children() const + { + return xml_object_range(begin(), end()); + } + + PUGI__FN xml_object_range xml_node::children(const char_t* name_) const + { + return xml_object_range(xml_named_node_iterator(child(name_)._root, _root, name_), xml_named_node_iterator(0, _root, name_)); + } + + PUGI__FN xml_object_range xml_node::attributes() const + { + return xml_object_range(attributes_begin(), attributes_end()); + } + + PUGI__FN bool xml_node::operator==(const xml_node& r) const + { + return (_root == r._root); + } + + PUGI__FN bool xml_node::operator!=(const xml_node& r) const + { + return (_root != r._root); + } + + PUGI__FN bool xml_node::operator<(const xml_node& r) const + { + return (_root < r._root); + } + + PUGI__FN bool xml_node::operator>(const xml_node& r) const + { + return (_root > r._root); + } + + PUGI__FN bool xml_node::operator<=(const xml_node& r) const + { + return (_root <= r._root); + } + + PUGI__FN bool xml_node::operator>=(const xml_node& r) const + { + return (_root >= r._root); + } + + PUGI__FN bool xml_node::empty() const + { + return !_root; + } + + PUGI__FN const char_t* xml_node::name() const + { + return (_root && _root->name) ? _root->name + 0 : PUGIXML_TEXT(""); + } + + PUGI__FN xml_node_type xml_node::type() const + { + return _root ? PUGI__NODETYPE(_root) : node_null; + } + + PUGI__FN const char_t* xml_node::value() const + { + return (_root && _root->value) ? _root->value + 0 : PUGIXML_TEXT(""); + } + + PUGI__FN xml_node xml_node::child(const char_t* name_) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (i->name && impl::strequal(name_, i->name)) return xml_node(i); + + return xml_node(); + } + + PUGI__FN xml_attribute xml_node::attribute(const char_t* name_) const + { + if (!_root) return xml_attribute(); + + for (xml_attribute_struct* i = _root->first_attribute; i; i = i->next_attribute) + if (i->name && impl::strequal(name_, i->name)) + return xml_attribute(i); + + return xml_attribute(); + } + + PUGI__FN xml_node xml_node::next_sibling(const char_t* name_) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->next_sibling; i; i = i->next_sibling) + if (i->name && impl::strequal(name_, i->name)) return xml_node(i); + + return xml_node(); + } + + PUGI__FN xml_node xml_node::next_sibling() const + { + return _root ? xml_node(_root->next_sibling) : xml_node(); + } + + PUGI__FN xml_node xml_node::previous_sibling(const char_t* name_) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->prev_sibling_c; i->next_sibling; i = i->prev_sibling_c) + if (i->name && impl::strequal(name_, i->name)) return xml_node(i); + + return xml_node(); + } + + PUGI__FN xml_attribute xml_node::attribute(const char_t* name_, xml_attribute& hint_) const + { + xml_attribute_struct* hint = hint_._attr; + + // if hint is not an attribute of node, behavior is not defined + assert(!hint || (_root && impl::is_attribute_of(hint, _root))); + + if (!_root) return xml_attribute(); + + // optimistically search from hint up until the end + for (xml_attribute_struct* i = hint; i; i = i->next_attribute) + if (i->name && impl::strequal(name_, i->name)) + { + // update hint to maximize efficiency of searching for consecutive attributes + hint_._attr = i->next_attribute; + + return xml_attribute(i); + } + + // wrap around and search from the first attribute until the hint + // 'j' null pointer check is technically redundant, but it prevents a crash in case the assertion above fails + for (xml_attribute_struct* j = _root->first_attribute; j && j != hint; j = j->next_attribute) + if (j->name && impl::strequal(name_, j->name)) + { + // update hint to maximize efficiency of searching for consecutive attributes + hint_._attr = j->next_attribute; + + return xml_attribute(j); + } + + return xml_attribute(); + } + + PUGI__FN xml_node xml_node::previous_sibling() const + { + if (!_root) return xml_node(); + + if (_root->prev_sibling_c->next_sibling) return xml_node(_root->prev_sibling_c); + else return xml_node(); + } + + PUGI__FN xml_node xml_node::parent() const + { + return _root ? xml_node(_root->parent) : xml_node(); + } + + PUGI__FN xml_node xml_node::root() const + { + return _root ? xml_node(&impl::get_document(_root)) : xml_node(); + } + + PUGI__FN xml_text xml_node::text() const + { + return xml_text(_root); + } + + PUGI__FN const char_t* xml_node::child_value() const + { + if (!_root) return PUGIXML_TEXT(""); + + // element nodes can have value if parse_embed_pcdata was used + if (PUGI__NODETYPE(_root) == node_element && _root->value) + return _root->value; + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (impl::is_text_node(i) && i->value) + return i->value; + + return PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* xml_node::child_value(const char_t* name_) const + { + return child(name_).child_value(); + } + + PUGI__FN xml_attribute xml_node::first_attribute() const + { + return _root ? xml_attribute(_root->first_attribute) : xml_attribute(); + } + + PUGI__FN xml_attribute xml_node::last_attribute() const + { + return _root && _root->first_attribute ? xml_attribute(_root->first_attribute->prev_attribute_c) : xml_attribute(); + } + + PUGI__FN xml_node xml_node::first_child() const + { + return _root ? xml_node(_root->first_child) : xml_node(); + } + + PUGI__FN xml_node xml_node::last_child() const + { + return _root && _root->first_child ? xml_node(_root->first_child->prev_sibling_c) : xml_node(); + } + + PUGI__FN bool xml_node::set_name(const char_t* rhs) + { + xml_node_type type_ = _root ? PUGI__NODETYPE(_root) : node_null; + + if (type_ != node_element && type_ != node_pi && type_ != node_declaration) + return false; + + return impl::strcpy_insitu(_root->name, _root->header, impl::xml_memory_page_name_allocated_mask, rhs, impl::strlength(rhs)); + } + + PUGI__FN bool xml_node::set_value(const char_t* rhs) + { + xml_node_type type_ = _root ? PUGI__NODETYPE(_root) : node_null; + + if (type_ != node_pcdata && type_ != node_cdata && type_ != node_comment && type_ != node_pi && type_ != node_doctype) + return false; + + return impl::strcpy_insitu(_root->value, _root->header, impl::xml_memory_page_value_allocated_mask, rhs, impl::strlength(rhs)); + } + + PUGI__FN xml_attribute xml_node::append_attribute(const char_t* name_) + { + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::append_attribute(a._attr, _root); + + a.set_name(name_); + + return a; + } + + PUGI__FN xml_attribute xml_node::prepend_attribute(const char_t* name_) + { + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::prepend_attribute(a._attr, _root); + + a.set_name(name_); + + return a; + } + + PUGI__FN xml_attribute xml_node::insert_attribute_after(const char_t* name_, const xml_attribute& attr) + { + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + if (!attr || !impl::is_attribute_of(attr._attr, _root)) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::insert_attribute_after(a._attr, attr._attr, _root); + + a.set_name(name_); + + return a; + } + + PUGI__FN xml_attribute xml_node::insert_attribute_before(const char_t* name_, const xml_attribute& attr) + { + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + if (!attr || !impl::is_attribute_of(attr._attr, _root)) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::insert_attribute_before(a._attr, attr._attr, _root); + + a.set_name(name_); + + return a; + } + + PUGI__FN xml_attribute xml_node::append_copy(const xml_attribute& proto) + { + if (!proto) return xml_attribute(); + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::append_attribute(a._attr, _root); + impl::node_copy_attribute(a._attr, proto._attr); + + return a; + } + + PUGI__FN xml_attribute xml_node::prepend_copy(const xml_attribute& proto) + { + if (!proto) return xml_attribute(); + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::prepend_attribute(a._attr, _root); + impl::node_copy_attribute(a._attr, proto._attr); + + return a; + } + + PUGI__FN xml_attribute xml_node::insert_copy_after(const xml_attribute& proto, const xml_attribute& attr) + { + if (!proto) return xml_attribute(); + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + if (!attr || !impl::is_attribute_of(attr._attr, _root)) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::insert_attribute_after(a._attr, attr._attr, _root); + impl::node_copy_attribute(a._attr, proto._attr); + + return a; + } + + PUGI__FN xml_attribute xml_node::insert_copy_before(const xml_attribute& proto, const xml_attribute& attr) + { + if (!proto) return xml_attribute(); + if (!impl::allow_insert_attribute(type())) return xml_attribute(); + if (!attr || !impl::is_attribute_of(attr._attr, _root)) return xml_attribute(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_attribute(); + + xml_attribute a(impl::allocate_attribute(alloc)); + if (!a) return xml_attribute(); + + impl::insert_attribute_before(a._attr, attr._attr, _root); + impl::node_copy_attribute(a._attr, proto._attr); + + return a; + } + + PUGI__FN xml_node xml_node::append_child(xml_node_type type_) + { + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::append_node(n._root, _root); + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::prepend_child(xml_node_type type_) + { + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::prepend_node(n._root, _root); + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::insert_child_before(xml_node_type type_, const xml_node& node) + { + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::insert_node_before(n._root, node._root); + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::insert_child_after(xml_node_type type_, const xml_node& node) + { + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::insert_node_after(n._root, node._root); + + if (type_ == node_declaration) n.set_name(PUGIXML_TEXT("xml")); + + return n; + } + + PUGI__FN xml_node xml_node::append_child(const char_t* name_) + { + xml_node result = append_child(node_element); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::prepend_child(const char_t* name_) + { + xml_node result = prepend_child(node_element); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::insert_child_after(const char_t* name_, const xml_node& node) + { + xml_node result = insert_child_after(node_element, node); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::insert_child_before(const char_t* name_, const xml_node& node) + { + xml_node result = insert_child_before(node_element, node); + + result.set_name(name_); + + return result; + } + + PUGI__FN xml_node xml_node::append_copy(const xml_node& proto) + { + xml_node_type type_ = proto.type(); + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::append_node(n._root, _root); + impl::node_copy_tree(n._root, proto._root); + + return n; + } + + PUGI__FN xml_node xml_node::prepend_copy(const xml_node& proto) + { + xml_node_type type_ = proto.type(); + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::prepend_node(n._root, _root); + impl::node_copy_tree(n._root, proto._root); + + return n; + } + + PUGI__FN xml_node xml_node::insert_copy_after(const xml_node& proto, const xml_node& node) + { + xml_node_type type_ = proto.type(); + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::insert_node_after(n._root, node._root); + impl::node_copy_tree(n._root, proto._root); + + return n; + } + + PUGI__FN xml_node xml_node::insert_copy_before(const xml_node& proto, const xml_node& node) + { + xml_node_type type_ = proto.type(); + if (!impl::allow_insert_child(type(), type_)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + xml_node n(impl::allocate_node(alloc, type_)); + if (!n) return xml_node(); + + impl::insert_node_before(n._root, node._root); + impl::node_copy_tree(n._root, proto._root); + + return n; + } + + PUGI__FN xml_node xml_node::append_move(const xml_node& moved) + { + if (!impl::allow_move(*this, moved)) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + // disable document_buffer_order optimization since moving nodes around changes document order without changing buffer pointers + impl::get_document(_root).header |= impl::xml_memory_page_contents_shared_mask; + + impl::remove_node(moved._root); + impl::append_node(moved._root, _root); + + return moved; + } + + PUGI__FN xml_node xml_node::prepend_move(const xml_node& moved) + { + if (!impl::allow_move(*this, moved)) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + // disable document_buffer_order optimization since moving nodes around changes document order without changing buffer pointers + impl::get_document(_root).header |= impl::xml_memory_page_contents_shared_mask; + + impl::remove_node(moved._root); + impl::prepend_node(moved._root, _root); + + return moved; + } + + PUGI__FN xml_node xml_node::insert_move_after(const xml_node& moved, const xml_node& node) + { + if (!impl::allow_move(*this, moved)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + if (moved._root == node._root) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + // disable document_buffer_order optimization since moving nodes around changes document order without changing buffer pointers + impl::get_document(_root).header |= impl::xml_memory_page_contents_shared_mask; + + impl::remove_node(moved._root); + impl::insert_node_after(moved._root, node._root); + + return moved; + } + + PUGI__FN xml_node xml_node::insert_move_before(const xml_node& moved, const xml_node& node) + { + if (!impl::allow_move(*this, moved)) return xml_node(); + if (!node._root || node._root->parent != _root) return xml_node(); + if (moved._root == node._root) return xml_node(); + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return xml_node(); + + // disable document_buffer_order optimization since moving nodes around changes document order without changing buffer pointers + impl::get_document(_root).header |= impl::xml_memory_page_contents_shared_mask; + + impl::remove_node(moved._root); + impl::insert_node_before(moved._root, node._root); + + return moved; + } + + PUGI__FN bool xml_node::remove_attribute(const char_t* name_) + { + return remove_attribute(attribute(name_)); + } + + PUGI__FN bool xml_node::remove_attribute(const xml_attribute& a) + { + if (!_root || !a._attr) return false; + if (!impl::is_attribute_of(a._attr, _root)) return false; + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return false; + + impl::remove_attribute(a._attr, _root); + impl::destroy_attribute(a._attr, alloc); + + return true; + } + + PUGI__FN bool xml_node::remove_attributes() + { + if (!_root) return false; + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return false; + + for (xml_attribute_struct* attr = _root->first_attribute; attr; ) + { + xml_attribute_struct* next = attr->next_attribute; + + impl::destroy_attribute(attr, alloc); + + attr = next; + } + + _root->first_attribute = 0; + + return true; + } + + PUGI__FN bool xml_node::remove_child(const char_t* name_) + { + return remove_child(child(name_)); + } + + PUGI__FN bool xml_node::remove_child(const xml_node& n) + { + if (!_root || !n._root || n._root->parent != _root) return false; + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return false; + + impl::remove_node(n._root); + impl::destroy_node(n._root, alloc); + + return true; + } + + PUGI__FN bool xml_node::remove_children() + { + if (!_root) return false; + + impl::xml_allocator& alloc = impl::get_allocator(_root); + if (!alloc.reserve()) return false; + + for (xml_node_struct* cur = _root->first_child; cur; ) + { + xml_node_struct* next = cur->next_sibling; + + impl::destroy_node(cur, alloc); + + cur = next; + } + + _root->first_child = 0; + + return true; + } + + PUGI__FN xml_parse_result xml_node::append_buffer(const void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + // append_buffer is only valid for elements/documents + if (!impl::allow_insert_child(type(), node_element)) return impl::make_parse_result(status_append_invalid_root); + + // get document node + impl::xml_document_struct* doc = &impl::get_document(_root); + + // disable document_buffer_order optimization since in a document with multiple buffers comparing buffer pointers does not make sense + doc->header |= impl::xml_memory_page_contents_shared_mask; + + // get extra buffer element (we'll store the document fragment buffer there so that we can deallocate it later) + impl::xml_memory_page* page = 0; + impl::xml_extra_buffer* extra = static_cast(doc->allocate_memory(sizeof(impl::xml_extra_buffer) + sizeof(void*), page)); + (void)page; + + if (!extra) return impl::make_parse_result(status_out_of_memory); + + #ifdef PUGIXML_COMPACT + // align the memory block to a pointer boundary; this is required for compact mode where memory allocations are only 4b aligned + // note that this requires up to sizeof(void*)-1 additional memory, which the allocation above takes into account + extra = reinterpret_cast((reinterpret_cast(extra) + (sizeof(void*) - 1)) & ~(sizeof(void*) - 1)); + #endif + + // add extra buffer to the list + extra->buffer = 0; + extra->next = doc->extra_buffers; + doc->extra_buffers = extra; + + // name of the root has to be NULL before parsing - otherwise closing node mismatches will not be detected at the top level + impl::name_null_sentry sentry(_root); + + return impl::load_buffer_impl(doc, _root, const_cast(contents), size, options, encoding, false, false, &extra->buffer); + } + + PUGI__FN xml_node xml_node::find_child_by_attribute(const char_t* name_, const char_t* attr_name, const char_t* attr_value) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (i->name && impl::strequal(name_, i->name)) + { + for (xml_attribute_struct* a = i->first_attribute; a; a = a->next_attribute) + if (a->name && impl::strequal(attr_name, a->name) && impl::strequal(attr_value, a->value ? a->value + 0 : PUGIXML_TEXT(""))) + return xml_node(i); + } + + return xml_node(); + } + + PUGI__FN xml_node xml_node::find_child_by_attribute(const char_t* attr_name, const char_t* attr_value) const + { + if (!_root) return xml_node(); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + for (xml_attribute_struct* a = i->first_attribute; a; a = a->next_attribute) + if (a->name && impl::strequal(attr_name, a->name) && impl::strequal(attr_value, a->value ? a->value + 0 : PUGIXML_TEXT(""))) + return xml_node(i); + + return xml_node(); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN string_t xml_node::path(char_t delimiter) const + { + if (!_root) return string_t(); + + size_t offset = 0; + + for (xml_node_struct* i = _root; i; i = i->parent) + { + offset += (i != _root); + offset += i->name ? impl::strlength(i->name) : 0; + } + + string_t result; + result.resize(offset); + + for (xml_node_struct* j = _root; j; j = j->parent) + { + if (j != _root) + result[--offset] = delimiter; + + if (j->name) + { + size_t length = impl::strlength(j->name); + + offset -= length; + memcpy(&result[offset], j->name, length * sizeof(char_t)); + } + } + + assert(offset == 0); + + return result; + } +#endif + + PUGI__FN xml_node xml_node::first_element_by_path(const char_t* path_, char_t delimiter) const + { + xml_node context = path_[0] == delimiter ? root() : *this; + + if (!context._root) return xml_node(); + + const char_t* path_segment = path_; + + while (*path_segment == delimiter) ++path_segment; + + const char_t* path_segment_end = path_segment; + + while (*path_segment_end && *path_segment_end != delimiter) ++path_segment_end; + + if (path_segment == path_segment_end) return context; + + const char_t* next_segment = path_segment_end; + + while (*next_segment == delimiter) ++next_segment; + + if (*path_segment == '.' && path_segment + 1 == path_segment_end) + return context.first_element_by_path(next_segment, delimiter); + else if (*path_segment == '.' && *(path_segment+1) == '.' && path_segment + 2 == path_segment_end) + return context.parent().first_element_by_path(next_segment, delimiter); + else + { + for (xml_node_struct* j = context._root->first_child; j; j = j->next_sibling) + { + if (j->name && impl::strequalrange(j->name, path_segment, static_cast(path_segment_end - path_segment))) + { + xml_node subsearch = xml_node(j).first_element_by_path(next_segment, delimiter); + + if (subsearch) return subsearch; + } + } + + return xml_node(); + } + } + + PUGI__FN bool xml_node::traverse(xml_tree_walker& walker) + { + walker._depth = -1; + + xml_node arg_begin(_root); + if (!walker.begin(arg_begin)) return false; + + xml_node_struct* cur = _root ? _root->first_child + 0 : 0; + + if (cur) + { + ++walker._depth; + + do + { + xml_node arg_for_each(cur); + if (!walker.for_each(arg_for_each)) + return false; + + if (cur->first_child) + { + ++walker._depth; + cur = cur->first_child; + } + else if (cur->next_sibling) + cur = cur->next_sibling; + else + { + while (!cur->next_sibling && cur != _root && cur->parent) + { + --walker._depth; + cur = cur->parent; + } + + if (cur != _root) + cur = cur->next_sibling; + } + } + while (cur && cur != _root); + } + + assert(walker._depth == -1); + + xml_node arg_end(_root); + return walker.end(arg_end); + } + + PUGI__FN size_t xml_node::hash_value() const + { + return static_cast(reinterpret_cast(_root) / sizeof(xml_node_struct)); + } + + PUGI__FN xml_node_struct* xml_node::internal_object() const + { + return _root; + } + + PUGI__FN void xml_node::print(xml_writer& writer, const char_t* indent, unsigned int flags, xml_encoding encoding, unsigned int depth) const + { + if (!_root) return; + + impl::xml_buffered_writer buffered_writer(writer, encoding); + + impl::node_output(buffered_writer, _root, indent, flags, depth); + + buffered_writer.flush(); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN void xml_node::print(std::basic_ostream >& stream, const char_t* indent, unsigned int flags, xml_encoding encoding, unsigned int depth) const + { + xml_writer_stream writer(stream); + + print(writer, indent, flags, encoding, depth); + } + + PUGI__FN void xml_node::print(std::basic_ostream >& stream, const char_t* indent, unsigned int flags, unsigned int depth) const + { + xml_writer_stream writer(stream); + + print(writer, indent, flags, encoding_wchar, depth); + } +#endif + + PUGI__FN ptrdiff_t xml_node::offset_debug() const + { + if (!_root) return -1; + + impl::xml_document_struct& doc = impl::get_document(_root); + + // we can determine the offset reliably only if there is exactly once parse buffer + if (!doc.buffer || doc.extra_buffers) return -1; + + switch (type()) + { + case node_document: + return 0; + + case node_element: + case node_declaration: + case node_pi: + return _root->name && (_root->header & impl::xml_memory_page_name_allocated_or_shared_mask) == 0 ? _root->name - doc.buffer : -1; + + case node_pcdata: + case node_cdata: + case node_comment: + case node_doctype: + return _root->value && (_root->header & impl::xml_memory_page_value_allocated_or_shared_mask) == 0 ? _root->value - doc.buffer : -1; + + default: + assert(false && "Invalid node type"); // unreachable + return -1; + } + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xml_node& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xml_node& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN xml_text::xml_text(xml_node_struct* root): _root(root) + { + } + + PUGI__FN xml_node_struct* xml_text::_data() const + { + if (!_root || impl::is_text_node(_root)) return _root; + + // element nodes can have value if parse_embed_pcdata was used + if (PUGI__NODETYPE(_root) == node_element && _root->value) + return _root; + + for (xml_node_struct* node = _root->first_child; node; node = node->next_sibling) + if (impl::is_text_node(node)) + return node; + + return 0; + } + + PUGI__FN xml_node_struct* xml_text::_data_new() + { + xml_node_struct* d = _data(); + if (d) return d; + + return xml_node(_root).append_child(node_pcdata).internal_object(); + } + + PUGI__FN xml_text::xml_text(): _root(0) + { + } + + PUGI__FN static void unspecified_bool_xml_text(xml_text***) + { + } + + PUGI__FN xml_text::operator xml_text::unspecified_bool_type() const + { + return _data() ? unspecified_bool_xml_text : 0; + } + + PUGI__FN bool xml_text::operator!() const + { + return !_data(); + } + + PUGI__FN bool xml_text::empty() const + { + return _data() == 0; + } + + PUGI__FN const char_t* xml_text::get() const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? d->value + 0 : PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* xml_text::as_string(const char_t* def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? d->value + 0 : def; + } + + PUGI__FN int xml_text::as_int(int def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_int(d->value) : def; + } + + PUGI__FN unsigned int xml_text::as_uint(unsigned int def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_uint(d->value) : def; + } + + PUGI__FN double xml_text::as_double(double def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_double(d->value) : def; + } + + PUGI__FN float xml_text::as_float(float def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_float(d->value) : def; + } + + PUGI__FN bool xml_text::as_bool(bool def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_bool(d->value) : def; + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN long long xml_text::as_llong(long long def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_llong(d->value) : def; + } + + PUGI__FN unsigned long long xml_text::as_ullong(unsigned long long def) const + { + xml_node_struct* d = _data(); + + return (d && d->value) ? impl::get_value_ullong(d->value) : def; + } +#endif + + PUGI__FN bool xml_text::set(const char_t* rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::strcpy_insitu(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, impl::strlength(rhs)) : false; + } + + PUGI__FN bool xml_text::set(int rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_integer(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, rhs < 0) : false; + } + + PUGI__FN bool xml_text::set(unsigned int rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_integer(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, false) : false; + } + + PUGI__FN bool xml_text::set(long rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_integer(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, rhs < 0) : false; + } + + PUGI__FN bool xml_text::set(unsigned long rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_integer(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, false) : false; + } + + PUGI__FN bool xml_text::set(float rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, default_float_precision) : false; + } + + PUGI__FN bool xml_text::set(float rhs, int precision) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, precision) : false; + } + + PUGI__FN bool xml_text::set(double rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, default_double_precision) : false; + } + + PUGI__FN bool xml_text::set(double rhs, int precision) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_convert(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, precision) : false; + } + + PUGI__FN bool xml_text::set(bool rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_bool(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs) : false; + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN bool xml_text::set(long long rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_integer(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, rhs < 0) : false; + } + + PUGI__FN bool xml_text::set(unsigned long long rhs) + { + xml_node_struct* dn = _data_new(); + + return dn ? impl::set_value_integer(dn->value, dn->header, impl::xml_memory_page_value_allocated_mask, rhs, false) : false; + } +#endif + + PUGI__FN xml_text& xml_text::operator=(const char_t* rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(int rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(unsigned int rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(long rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(unsigned long rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(double rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(float rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(bool rhs) + { + set(rhs); + return *this; + } + +#ifdef PUGIXML_HAS_LONG_LONG + PUGI__FN xml_text& xml_text::operator=(long long rhs) + { + set(rhs); + return *this; + } + + PUGI__FN xml_text& xml_text::operator=(unsigned long long rhs) + { + set(rhs); + return *this; + } +#endif + + PUGI__FN xml_node xml_text::data() const + { + return xml_node(_data()); + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xml_text& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xml_text& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN xml_node_iterator::xml_node_iterator() + { + } + + PUGI__FN xml_node_iterator::xml_node_iterator(const xml_node& node): _wrap(node), _parent(node.parent()) + { + } + + PUGI__FN xml_node_iterator::xml_node_iterator(xml_node_struct* ref, xml_node_struct* parent): _wrap(ref), _parent(parent) + { + } + + PUGI__FN bool xml_node_iterator::operator==(const xml_node_iterator& rhs) const + { + return _wrap._root == rhs._wrap._root && _parent._root == rhs._parent._root; + } + + PUGI__FN bool xml_node_iterator::operator!=(const xml_node_iterator& rhs) const + { + return _wrap._root != rhs._wrap._root || _parent._root != rhs._parent._root; + } + + PUGI__FN xml_node& xml_node_iterator::operator*() const + { + assert(_wrap._root); + return _wrap; + } + + PUGI__FN xml_node* xml_node_iterator::operator->() const + { + assert(_wrap._root); + return const_cast(&_wrap); // BCC5 workaround + } + + PUGI__FN xml_node_iterator& xml_node_iterator::operator++() + { + assert(_wrap._root); + _wrap._root = _wrap._root->next_sibling; + return *this; + } + + PUGI__FN xml_node_iterator xml_node_iterator::operator++(int) + { + xml_node_iterator temp = *this; + ++*this; + return temp; + } + + PUGI__FN xml_node_iterator& xml_node_iterator::operator--() + { + _wrap = _wrap._root ? _wrap.previous_sibling() : _parent.last_child(); + return *this; + } + + PUGI__FN xml_node_iterator xml_node_iterator::operator--(int) + { + xml_node_iterator temp = *this; + --*this; + return temp; + } + + PUGI__FN xml_attribute_iterator::xml_attribute_iterator() + { + } + + PUGI__FN xml_attribute_iterator::xml_attribute_iterator(const xml_attribute& attr, const xml_node& parent): _wrap(attr), _parent(parent) + { + } + + PUGI__FN xml_attribute_iterator::xml_attribute_iterator(xml_attribute_struct* ref, xml_node_struct* parent): _wrap(ref), _parent(parent) + { + } + + PUGI__FN bool xml_attribute_iterator::operator==(const xml_attribute_iterator& rhs) const + { + return _wrap._attr == rhs._wrap._attr && _parent._root == rhs._parent._root; + } + + PUGI__FN bool xml_attribute_iterator::operator!=(const xml_attribute_iterator& rhs) const + { + return _wrap._attr != rhs._wrap._attr || _parent._root != rhs._parent._root; + } + + PUGI__FN xml_attribute& xml_attribute_iterator::operator*() const + { + assert(_wrap._attr); + return _wrap; + } + + PUGI__FN xml_attribute* xml_attribute_iterator::operator->() const + { + assert(_wrap._attr); + return const_cast(&_wrap); // BCC5 workaround + } + + PUGI__FN xml_attribute_iterator& xml_attribute_iterator::operator++() + { + assert(_wrap._attr); + _wrap._attr = _wrap._attr->next_attribute; + return *this; + } + + PUGI__FN xml_attribute_iterator xml_attribute_iterator::operator++(int) + { + xml_attribute_iterator temp = *this; + ++*this; + return temp; + } + + PUGI__FN xml_attribute_iterator& xml_attribute_iterator::operator--() + { + _wrap = _wrap._attr ? _wrap.previous_attribute() : _parent.last_attribute(); + return *this; + } + + PUGI__FN xml_attribute_iterator xml_attribute_iterator::operator--(int) + { + xml_attribute_iterator temp = *this; + --*this; + return temp; + } + + PUGI__FN xml_named_node_iterator::xml_named_node_iterator(): _name(0) + { + } + + PUGI__FN xml_named_node_iterator::xml_named_node_iterator(const xml_node& node, const char_t* name): _wrap(node), _parent(node.parent()), _name(name) + { + } + + PUGI__FN xml_named_node_iterator::xml_named_node_iterator(xml_node_struct* ref, xml_node_struct* parent, const char_t* name): _wrap(ref), _parent(parent), _name(name) + { + } + + PUGI__FN bool xml_named_node_iterator::operator==(const xml_named_node_iterator& rhs) const + { + return _wrap._root == rhs._wrap._root && _parent._root == rhs._parent._root; + } + + PUGI__FN bool xml_named_node_iterator::operator!=(const xml_named_node_iterator& rhs) const + { + return _wrap._root != rhs._wrap._root || _parent._root != rhs._parent._root; + } + + PUGI__FN xml_node& xml_named_node_iterator::operator*() const + { + assert(_wrap._root); + return _wrap; + } + + PUGI__FN xml_node* xml_named_node_iterator::operator->() const + { + assert(_wrap._root); + return const_cast(&_wrap); // BCC5 workaround + } + + PUGI__FN xml_named_node_iterator& xml_named_node_iterator::operator++() + { + assert(_wrap._root); + _wrap = _wrap.next_sibling(_name); + return *this; + } + + PUGI__FN xml_named_node_iterator xml_named_node_iterator::operator++(int) + { + xml_named_node_iterator temp = *this; + ++*this; + return temp; + } + + PUGI__FN xml_named_node_iterator& xml_named_node_iterator::operator--() + { + if (_wrap._root) + _wrap = _wrap.previous_sibling(_name); + else + { + _wrap = _parent.last_child(); + + if (!impl::strequal(_wrap.name(), _name)) + _wrap = _wrap.previous_sibling(_name); + } + + return *this; + } + + PUGI__FN xml_named_node_iterator xml_named_node_iterator::operator--(int) + { + xml_named_node_iterator temp = *this; + --*this; + return temp; + } + + PUGI__FN xml_parse_result::xml_parse_result(): status(status_internal_error), offset(0), encoding(encoding_auto) + { + } + + PUGI__FN xml_parse_result::operator bool() const + { + return status == status_ok; + } + + PUGI__FN const char* xml_parse_result::description() const + { + switch (status) + { + case status_ok: return "No error"; + + case status_file_not_found: return "File was not found"; + case status_io_error: return "Error reading from file/stream"; + case status_out_of_memory: return "Could not allocate memory"; + case status_internal_error: return "Internal error occurred"; + + case status_unrecognized_tag: return "Could not determine tag type"; + + case status_bad_pi: return "Error parsing document declaration/processing instruction"; + case status_bad_comment: return "Error parsing comment"; + case status_bad_cdata: return "Error parsing CDATA section"; + case status_bad_doctype: return "Error parsing document type declaration"; + case status_bad_pcdata: return "Error parsing PCDATA section"; + case status_bad_start_element: return "Error parsing start element tag"; + case status_bad_attribute: return "Error parsing element attribute"; + case status_bad_end_element: return "Error parsing end element tag"; + case status_end_element_mismatch: return "Start-end tags mismatch"; + + case status_append_invalid_root: return "Unable to append nodes: root is not an element or document"; + + case status_no_document_element: return "No document element found"; + + default: return "Unknown error"; + } + } + + PUGI__FN xml_document::xml_document(): _buffer(0) + { + _create(); + } + + PUGI__FN xml_document::~xml_document() + { + _destroy(); + } + +#ifdef PUGIXML_HAS_MOVE + PUGI__FN xml_document::xml_document(xml_document&& rhs) PUGIXML_NOEXCEPT_IF_NOT_COMPACT: _buffer(0) + { + _create(); + _move(rhs); + } + + PUGI__FN xml_document& xml_document::operator=(xml_document&& rhs) PUGIXML_NOEXCEPT_IF_NOT_COMPACT + { + if (this == &rhs) return *this; + + _destroy(); + _create(); + _move(rhs); + + return *this; + } +#endif + + PUGI__FN void xml_document::reset() + { + _destroy(); + _create(); + } + + PUGI__FN void xml_document::reset(const xml_document& proto) + { + reset(); + + impl::node_copy_tree(_root, proto._root); + } + + PUGI__FN void xml_document::_create() + { + assert(!_root); + + #ifdef PUGIXML_COMPACT + // space for page marker for the first page (uint32_t), rounded up to pointer size; assumes pointers are at least 32-bit + const size_t page_offset = sizeof(void*); + #else + const size_t page_offset = 0; + #endif + + // initialize sentinel page + PUGI__STATIC_ASSERT(sizeof(impl::xml_memory_page) + sizeof(impl::xml_document_struct) + page_offset <= sizeof(_memory)); + + // prepare page structure + impl::xml_memory_page* page = impl::xml_memory_page::construct(_memory); + assert(page); + + page->busy_size = impl::xml_memory_page_size; + + // setup first page marker + #ifdef PUGIXML_COMPACT + // round-trip through void* to avoid 'cast increases required alignment of target type' warning + page->compact_page_marker = reinterpret_cast(static_cast(reinterpret_cast(page) + sizeof(impl::xml_memory_page))); + *page->compact_page_marker = sizeof(impl::xml_memory_page); + #endif + + // allocate new root + _root = new (reinterpret_cast(page) + sizeof(impl::xml_memory_page) + page_offset) impl::xml_document_struct(page); + _root->prev_sibling_c = _root; + + // setup sentinel page + page->allocator = static_cast(_root); + + // setup hash table pointer in allocator + #ifdef PUGIXML_COMPACT + page->allocator->_hash = &static_cast(_root)->hash; + #endif + + // verify the document allocation + assert(reinterpret_cast(_root) + sizeof(impl::xml_document_struct) <= _memory + sizeof(_memory)); + } + + PUGI__FN void xml_document::_destroy() + { + assert(_root); + + // destroy static storage + if (_buffer) + { + impl::xml_memory::deallocate(_buffer); + _buffer = 0; + } + + // destroy extra buffers (note: no need to destroy linked list nodes, they're allocated using document allocator) + for (impl::xml_extra_buffer* extra = static_cast(_root)->extra_buffers; extra; extra = extra->next) + { + if (extra->buffer) impl::xml_memory::deallocate(extra->buffer); + } + + // destroy dynamic storage, leave sentinel page (it's in static memory) + impl::xml_memory_page* root_page = PUGI__GETPAGE(_root); + assert(root_page && !root_page->prev); + assert(reinterpret_cast(root_page) >= _memory && reinterpret_cast(root_page) < _memory + sizeof(_memory)); + + for (impl::xml_memory_page* page = root_page->next; page; ) + { + impl::xml_memory_page* next = page->next; + + impl::xml_allocator::deallocate_page(page); + + page = next; + } + + #ifdef PUGIXML_COMPACT + // destroy hash table + static_cast(_root)->hash.clear(); + #endif + + _root = 0; + } + +#ifdef PUGIXML_HAS_MOVE + PUGI__FN void xml_document::_move(xml_document& rhs) PUGIXML_NOEXCEPT_IF_NOT_COMPACT + { + impl::xml_document_struct* doc = static_cast(_root); + impl::xml_document_struct* other = static_cast(rhs._root); + + // save first child pointer for later; this needs hash access + xml_node_struct* other_first_child = other->first_child; + + #ifdef PUGIXML_COMPACT + // reserve space for the hash table up front; this is the only operation that can fail + // if it does, we have no choice but to throw (if we have exceptions) + if (other_first_child) + { + size_t other_children = 0; + for (xml_node_struct* node = other_first_child; node; node = node->next_sibling) + other_children++; + + // in compact mode, each pointer assignment could result in a hash table request + // during move, we have to relocate document first_child and parents of all children + // normally there's just one child and its parent has a pointerless encoding but + // we assume the worst here + if (!other->_hash->reserve(other_children + 1)) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return; + #else + throw std::bad_alloc(); + #endif + } + } + #endif + + // move allocation state + // note that other->_root may point to the embedded document page, in which case we should keep original (empty) state + if (other->_root != PUGI__GETPAGE(other)) + { + doc->_root = other->_root; + doc->_busy_size = other->_busy_size; + } + + // move buffer state + doc->buffer = other->buffer; + doc->extra_buffers = other->extra_buffers; + _buffer = rhs._buffer; + + #ifdef PUGIXML_COMPACT + // move compact hash; note that the hash table can have pointers to other but they will be "inactive", similarly to nodes removed with remove_child + doc->hash = other->hash; + doc->_hash = &doc->hash; + + // make sure we don't access other hash up until the end when we reinitialize other document + other->_hash = 0; + #endif + + // move page structure + impl::xml_memory_page* doc_page = PUGI__GETPAGE(doc); + assert(doc_page && !doc_page->prev && !doc_page->next); + + impl::xml_memory_page* other_page = PUGI__GETPAGE(other); + assert(other_page && !other_page->prev); + + // relink pages since root page is embedded into xml_document + if (impl::xml_memory_page* page = other_page->next) + { + assert(page->prev == other_page); + + page->prev = doc_page; + + doc_page->next = page; + other_page->next = 0; + } + + // make sure pages point to the correct document state + for (impl::xml_memory_page* page = doc_page->next; page; page = page->next) + { + assert(page->allocator == other); + + page->allocator = doc; + + #ifdef PUGIXML_COMPACT + // this automatically migrates most children between documents and prevents ->parent assignment from allocating + if (page->compact_shared_parent == other) + page->compact_shared_parent = doc; + #endif + } + + // move tree structure + assert(!doc->first_child); + + doc->first_child = other_first_child; + + for (xml_node_struct* node = other_first_child; node; node = node->next_sibling) + { + #ifdef PUGIXML_COMPACT + // most children will have migrated when we reassigned compact_shared_parent + assert(node->parent == other || node->parent == doc); + + node->parent = doc; + #else + assert(node->parent == other); + node->parent = doc; + #endif + } + + // reset other document + new (other) impl::xml_document_struct(PUGI__GETPAGE(other)); + rhs._buffer = 0; + } +#endif + +#ifndef PUGIXML_NO_STL + PUGI__FN xml_parse_result xml_document::load(std::basic_istream >& stream, unsigned int options, xml_encoding encoding) + { + reset(); + + return impl::load_stream_impl(static_cast(_root), stream, options, encoding, &_buffer); + } + + PUGI__FN xml_parse_result xml_document::load(std::basic_istream >& stream, unsigned int options) + { + reset(); + + return impl::load_stream_impl(static_cast(_root), stream, options, encoding_wchar, &_buffer); + } +#endif + + PUGI__FN xml_parse_result xml_document::load_string(const char_t* contents, unsigned int options) + { + // Force native encoding (skip autodetection) + #ifdef PUGIXML_WCHAR_MODE + xml_encoding encoding = encoding_wchar; + #else + xml_encoding encoding = encoding_utf8; + #endif + + return load_buffer(contents, impl::strlength(contents) * sizeof(char_t), options, encoding); + } + + PUGI__FN xml_parse_result xml_document::load(const char_t* contents, unsigned int options) + { + return load_string(contents, options); + } + + PUGI__FN xml_parse_result xml_document::load_file(const char* path_, unsigned int options, xml_encoding encoding) + { + reset(); + + using impl::auto_deleter; // MSVC7 workaround + auto_deleter file(impl::open_file(path_, "rb"), impl::close_file); + + return impl::load_file_impl(static_cast(_root), file.data, options, encoding, &_buffer); + } + + PUGI__FN xml_parse_result xml_document::load_file(const wchar_t* path_, unsigned int options, xml_encoding encoding) + { + reset(); + + using impl::auto_deleter; // MSVC7 workaround + auto_deleter file(impl::open_file_wide(path_, L"rb"), impl::close_file); + + return impl::load_file_impl(static_cast(_root), file.data, options, encoding, &_buffer); + } + + PUGI__FN xml_parse_result xml_document::load_buffer(const void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + reset(); + + return impl::load_buffer_impl(static_cast(_root), _root, const_cast(contents), size, options, encoding, false, false, &_buffer); + } + + PUGI__FN xml_parse_result xml_document::load_buffer_inplace(void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + reset(); + + return impl::load_buffer_impl(static_cast(_root), _root, contents, size, options, encoding, true, false, &_buffer); + } + + PUGI__FN xml_parse_result xml_document::load_buffer_inplace_own(void* contents, size_t size, unsigned int options, xml_encoding encoding) + { + reset(); + + return impl::load_buffer_impl(static_cast(_root), _root, contents, size, options, encoding, true, true, &_buffer); + } + + PUGI__FN void xml_document::save(xml_writer& writer, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + impl::xml_buffered_writer buffered_writer(writer, encoding); + + if ((flags & format_write_bom) && encoding != encoding_latin1) + { + // BOM always represents the codepoint U+FEFF, so just write it in native encoding + #ifdef PUGIXML_WCHAR_MODE + unsigned int bom = 0xfeff; + buffered_writer.write(static_cast(bom)); + #else + buffered_writer.write('\xef', '\xbb', '\xbf'); + #endif + } + + if (!(flags & format_no_declaration) && !impl::has_declaration(_root)) + { + buffered_writer.write_string(PUGIXML_TEXT("'); + if (!(flags & format_raw)) buffered_writer.write('\n'); + } + + impl::node_output(buffered_writer, _root, indent, flags, 0); + + buffered_writer.flush(); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN void xml_document::save(std::basic_ostream >& stream, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + xml_writer_stream writer(stream); + + save(writer, indent, flags, encoding); + } + + PUGI__FN void xml_document::save(std::basic_ostream >& stream, const char_t* indent, unsigned int flags) const + { + xml_writer_stream writer(stream); + + save(writer, indent, flags, encoding_wchar); + } +#endif + + PUGI__FN bool xml_document::save_file(const char* path_, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + using impl::auto_deleter; // MSVC7 workaround + auto_deleter file(impl::open_file(path_, (flags & format_save_file_text) ? "w" : "wb"), impl::close_file); + + return impl::save_file_impl(*this, file.data, indent, flags, encoding); + } + + PUGI__FN bool xml_document::save_file(const wchar_t* path_, const char_t* indent, unsigned int flags, xml_encoding encoding) const + { + using impl::auto_deleter; // MSVC7 workaround + auto_deleter file(impl::open_file_wide(path_, (flags & format_save_file_text) ? L"w" : L"wb"), impl::close_file); + + return impl::save_file_impl(*this, file.data, indent, flags, encoding); + } + + PUGI__FN xml_node xml_document::document_element() const + { + assert(_root); + + for (xml_node_struct* i = _root->first_child; i; i = i->next_sibling) + if (PUGI__NODETYPE(i) == node_element) + return xml_node(i); + + return xml_node(); + } + +#ifndef PUGIXML_NO_STL + PUGI__FN std::string PUGIXML_FUNCTION as_utf8(const wchar_t* str) + { + assert(str); + + return impl::as_utf8_impl(str, impl::strlength_wide(str)); + } + + PUGI__FN std::string PUGIXML_FUNCTION as_utf8(const std::basic_string& str) + { + return impl::as_utf8_impl(str.c_str(), str.size()); + } + + PUGI__FN std::basic_string PUGIXML_FUNCTION as_wide(const char* str) + { + assert(str); + + return impl::as_wide_impl(str, strlen(str)); + } + + PUGI__FN std::basic_string PUGIXML_FUNCTION as_wide(const std::string& str) + { + return impl::as_wide_impl(str.c_str(), str.size()); + } +#endif + + PUGI__FN void PUGIXML_FUNCTION set_memory_management_functions(allocation_function allocate, deallocation_function deallocate) + { + impl::xml_memory::allocate = allocate; + impl::xml_memory::deallocate = deallocate; + } + + PUGI__FN allocation_function PUGIXML_FUNCTION get_memory_allocation_function() + { + return impl::xml_memory::allocate; + } + + PUGI__FN deallocation_function PUGIXML_FUNCTION get_memory_deallocation_function() + { + return impl::xml_memory::deallocate; + } +} + +#if !defined(PUGIXML_NO_STL) && (defined(_MSC_VER) || defined(__ICC)) +namespace std +{ + // Workarounds for (non-standard) iterator category detection for older versions (MSVC7/IC8 and earlier) + PUGI__FN std::bidirectional_iterator_tag _Iter_cat(const pugi::xml_node_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::bidirectional_iterator_tag _Iter_cat(const pugi::xml_attribute_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::bidirectional_iterator_tag _Iter_cat(const pugi::xml_named_node_iterator&) + { + return std::bidirectional_iterator_tag(); + } +} +#endif + +#if !defined(PUGIXML_NO_STL) && defined(__SUNPRO_CC) +namespace std +{ + // Workarounds for (non-standard) iterator category detection + PUGI__FN std::bidirectional_iterator_tag __iterator_category(const pugi::xml_node_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::bidirectional_iterator_tag __iterator_category(const pugi::xml_attribute_iterator&) + { + return std::bidirectional_iterator_tag(); + } + + PUGI__FN std::bidirectional_iterator_tag __iterator_category(const pugi::xml_named_node_iterator&) + { + return std::bidirectional_iterator_tag(); + } +} +#endif + +#ifndef PUGIXML_NO_XPATH +// STL replacements +PUGI__NS_BEGIN + struct equal_to + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs == rhs; + } + }; + + struct not_equal_to + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs != rhs; + } + }; + + struct less + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs < rhs; + } + }; + + struct less_equal + { + template bool operator()(const T& lhs, const T& rhs) const + { + return lhs <= rhs; + } + }; + + template inline void swap(T& lhs, T& rhs) + { + T temp = lhs; + lhs = rhs; + rhs = temp; + } + + template PUGI__FN I min_element(I begin, I end, const Pred& pred) + { + I result = begin; + + for (I it = begin + 1; it != end; ++it) + if (pred(*it, *result)) + result = it; + + return result; + } + + template PUGI__FN void reverse(I begin, I end) + { + while (end - begin > 1) + swap(*begin++, *--end); + } + + template PUGI__FN I unique(I begin, I end) + { + // fast skip head + while (end - begin > 1 && *begin != *(begin + 1)) + begin++; + + if (begin == end) + return begin; + + // last written element + I write = begin++; + + // merge unique elements + while (begin != end) + { + if (*begin != *write) + *++write = *begin++; + else + begin++; + } + + // past-the-end (write points to live element) + return write + 1; + } + + template PUGI__FN void insertion_sort(T* begin, T* end, const Pred& pred) + { + if (begin == end) + return; + + for (T* it = begin + 1; it != end; ++it) + { + T val = *it; + T* hole = it; + + // move hole backwards + while (hole > begin && pred(val, *(hole - 1))) + { + *hole = *(hole - 1); + hole--; + } + + // fill hole with element + *hole = val; + } + } + + template inline I median3(I first, I middle, I last, const Pred& pred) + { + if (pred(*middle, *first)) + swap(middle, first); + if (pred(*last, *middle)) + swap(last, middle); + if (pred(*middle, *first)) + swap(middle, first); + + return middle; + } + + template PUGI__FN void partition3(T* begin, T* end, T pivot, const Pred& pred, T** out_eqbeg, T** out_eqend) + { + // invariant: array is split into 4 groups: = < ? > (each variable denotes the boundary between the groups) + T* eq = begin; + T* lt = begin; + T* gt = end; + + while (lt < gt) + { + if (pred(*lt, pivot)) + lt++; + else if (*lt == pivot) + swap(*eq++, *lt++); + else + swap(*lt, *--gt); + } + + // we now have just 4 groups: = < >; move equal elements to the middle + T* eqbeg = gt; + + for (T* it = begin; it != eq; ++it) + swap(*it, *--eqbeg); + + *out_eqbeg = eqbeg; + *out_eqend = gt; + } + + template PUGI__FN void sort(I begin, I end, const Pred& pred) + { + // sort large chunks + while (end - begin > 16) + { + // find median element + I middle = begin + (end - begin) / 2; + I median = median3(begin, middle, end - 1, pred); + + // partition in three chunks (< = >) + I eqbeg, eqend; + partition3(begin, end, *median, pred, &eqbeg, &eqend); + + // loop on larger half + if (eqbeg - begin > end - eqend) + { + sort(eqend, end, pred); + end = eqbeg; + } + else + { + sort(begin, eqbeg, pred); + begin = eqend; + } + } + + // insertion sort small chunk + insertion_sort(begin, end, pred); + } + + PUGI__FN bool hash_insert(const void** table, size_t size, const void* key) + { + assert(key); + + unsigned int h = static_cast(reinterpret_cast(key)); + + // MurmurHash3 32-bit finalizer + h ^= h >> 16; + h *= 0x85ebca6bu; + h ^= h >> 13; + h *= 0xc2b2ae35u; + h ^= h >> 16; + + size_t hashmod = size - 1; + size_t bucket = h & hashmod; + + for (size_t probe = 0; probe <= hashmod; ++probe) + { + if (table[bucket] == 0) + { + table[bucket] = key; + return true; + } + + if (table[bucket] == key) + return false; + + // hash collision, quadratic probing + bucket = (bucket + probe + 1) & hashmod; + } + + assert(false && "Hash table is full"); // unreachable + return false; + } +PUGI__NS_END + +// Allocator used for AST and evaluation stacks +PUGI__NS_BEGIN + static const size_t xpath_memory_page_size = + #ifdef PUGIXML_MEMORY_XPATH_PAGE_SIZE + PUGIXML_MEMORY_XPATH_PAGE_SIZE + #else + 4096 + #endif + ; + + static const uintptr_t xpath_memory_block_alignment = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); + + struct xpath_memory_block + { + xpath_memory_block* next; + size_t capacity; + + union + { + char data[xpath_memory_page_size]; + double alignment; + }; + }; + + struct xpath_allocator + { + xpath_memory_block* _root; + size_t _root_size; + bool* _error; + + xpath_allocator(xpath_memory_block* root, bool* error = 0): _root(root), _root_size(0), _error(error) + { + } + + void* allocate(size_t size) + { + // round size up to block alignment boundary + size = (size + xpath_memory_block_alignment - 1) & ~(xpath_memory_block_alignment - 1); + + if (_root_size + size <= _root->capacity) + { + void* buf = &_root->data[0] + _root_size; + _root_size += size; + return buf; + } + else + { + // make sure we have at least 1/4th of the page free after allocation to satisfy subsequent allocation requests + size_t block_capacity_base = sizeof(_root->data); + size_t block_capacity_req = size + block_capacity_base / 4; + size_t block_capacity = (block_capacity_base > block_capacity_req) ? block_capacity_base : block_capacity_req; + + size_t block_size = block_capacity + offsetof(xpath_memory_block, data); + + xpath_memory_block* block = static_cast(xml_memory::allocate(block_size)); + if (!block) + { + if (_error) *_error = true; + return 0; + } + + block->next = _root; + block->capacity = block_capacity; + + _root = block; + _root_size = size; + + return block->data; + } + } + + void* reallocate(void* ptr, size_t old_size, size_t new_size) + { + // round size up to block alignment boundary + old_size = (old_size + xpath_memory_block_alignment - 1) & ~(xpath_memory_block_alignment - 1); + new_size = (new_size + xpath_memory_block_alignment - 1) & ~(xpath_memory_block_alignment - 1); + + // we can only reallocate the last object + assert(ptr == 0 || static_cast(ptr) + old_size == &_root->data[0] + _root_size); + + // try to reallocate the object inplace + if (ptr && _root_size - old_size + new_size <= _root->capacity) + { + _root_size = _root_size - old_size + new_size; + return ptr; + } + + // allocate a new block + void* result = allocate(new_size); + if (!result) return 0; + + // we have a new block + if (ptr) + { + // copy old data (we only support growing) + assert(new_size >= old_size); + memcpy(result, ptr, old_size); + + // free the previous page if it had no other objects + assert(_root->data == result); + assert(_root->next); + + if (_root->next->data == ptr) + { + // deallocate the whole page, unless it was the first one + xpath_memory_block* next = _root->next->next; + + if (next) + { + xml_memory::deallocate(_root->next); + _root->next = next; + } + } + } + + return result; + } + + void revert(const xpath_allocator& state) + { + // free all new pages + xpath_memory_block* cur = _root; + + while (cur != state._root) + { + xpath_memory_block* next = cur->next; + + xml_memory::deallocate(cur); + + cur = next; + } + + // restore state + _root = state._root; + _root_size = state._root_size; + } + + void release() + { + xpath_memory_block* cur = _root; + assert(cur); + + while (cur->next) + { + xpath_memory_block* next = cur->next; + + xml_memory::deallocate(cur); + + cur = next; + } + } + }; + + struct xpath_allocator_capture + { + xpath_allocator_capture(xpath_allocator* alloc): _target(alloc), _state(*alloc) + { + } + + ~xpath_allocator_capture() + { + _target->revert(_state); + } + + xpath_allocator* _target; + xpath_allocator _state; + }; + + struct xpath_stack + { + xpath_allocator* result; + xpath_allocator* temp; + }; + + struct xpath_stack_data + { + xpath_memory_block blocks[2]; + xpath_allocator result; + xpath_allocator temp; + xpath_stack stack; + bool oom; + + xpath_stack_data(): result(blocks + 0, &oom), temp(blocks + 1, &oom), oom(false) + { + blocks[0].next = blocks[1].next = 0; + blocks[0].capacity = blocks[1].capacity = sizeof(blocks[0].data); + + stack.result = &result; + stack.temp = &temp; + } + + ~xpath_stack_data() + { + result.release(); + temp.release(); + } + }; +PUGI__NS_END + +// String class +PUGI__NS_BEGIN + class xpath_string + { + const char_t* _buffer; + bool _uses_heap; + size_t _length_heap; + + static char_t* duplicate_string(const char_t* string, size_t length, xpath_allocator* alloc) + { + char_t* result = static_cast(alloc->allocate((length + 1) * sizeof(char_t))); + if (!result) return 0; + + memcpy(result, string, length * sizeof(char_t)); + result[length] = 0; + + return result; + } + + xpath_string(const char_t* buffer, bool uses_heap_, size_t length_heap): _buffer(buffer), _uses_heap(uses_heap_), _length_heap(length_heap) + { + } + + public: + static xpath_string from_const(const char_t* str) + { + return xpath_string(str, false, 0); + } + + static xpath_string from_heap_preallocated(const char_t* begin, const char_t* end) + { + assert(begin <= end && *end == 0); + + return xpath_string(begin, true, static_cast(end - begin)); + } + + static xpath_string from_heap(const char_t* begin, const char_t* end, xpath_allocator* alloc) + { + assert(begin <= end); + + if (begin == end) + return xpath_string(); + + size_t length = static_cast(end - begin); + const char_t* data = duplicate_string(begin, length, alloc); + + return data ? xpath_string(data, true, length) : xpath_string(); + } + + xpath_string(): _buffer(PUGIXML_TEXT("")), _uses_heap(false), _length_heap(0) + { + } + + void append(const xpath_string& o, xpath_allocator* alloc) + { + // skip empty sources + if (!*o._buffer) return; + + // fast append for constant empty target and constant source + if (!*_buffer && !_uses_heap && !o._uses_heap) + { + _buffer = o._buffer; + } + else + { + // need to make heap copy + size_t target_length = length(); + size_t source_length = o.length(); + size_t result_length = target_length + source_length; + + // allocate new buffer + char_t* result = static_cast(alloc->reallocate(_uses_heap ? const_cast(_buffer) : 0, (target_length + 1) * sizeof(char_t), (result_length + 1) * sizeof(char_t))); + if (!result) return; + + // append first string to the new buffer in case there was no reallocation + if (!_uses_heap) memcpy(result, _buffer, target_length * sizeof(char_t)); + + // append second string to the new buffer + memcpy(result + target_length, o._buffer, source_length * sizeof(char_t)); + result[result_length] = 0; + + // finalize + _buffer = result; + _uses_heap = true; + _length_heap = result_length; + } + } + + const char_t* c_str() const + { + return _buffer; + } + + size_t length() const + { + return _uses_heap ? _length_heap : strlength(_buffer); + } + + char_t* data(xpath_allocator* alloc) + { + // make private heap copy + if (!_uses_heap) + { + size_t length_ = strlength(_buffer); + const char_t* data_ = duplicate_string(_buffer, length_, alloc); + + if (!data_) return 0; + + _buffer = data_; + _uses_heap = true; + _length_heap = length_; + } + + return const_cast(_buffer); + } + + bool empty() const + { + return *_buffer == 0; + } + + bool operator==(const xpath_string& o) const + { + return strequal(_buffer, o._buffer); + } + + bool operator!=(const xpath_string& o) const + { + return !strequal(_buffer, o._buffer); + } + + bool uses_heap() const + { + return _uses_heap; + } + }; +PUGI__NS_END + +PUGI__NS_BEGIN + PUGI__FN bool starts_with(const char_t* string, const char_t* pattern) + { + while (*pattern && *string == *pattern) + { + string++; + pattern++; + } + + return *pattern == 0; + } + + PUGI__FN const char_t* find_char(const char_t* s, char_t c) + { + #ifdef PUGIXML_WCHAR_MODE + return wcschr(s, c); + #else + return strchr(s, c); + #endif + } + + PUGI__FN const char_t* find_substring(const char_t* s, const char_t* p) + { + #ifdef PUGIXML_WCHAR_MODE + // MSVC6 wcsstr bug workaround (if s is empty it always returns 0) + return (*p == 0) ? s : wcsstr(s, p); + #else + return strstr(s, p); + #endif + } + + // Converts symbol to lower case, if it is an ASCII one + PUGI__FN char_t tolower_ascii(char_t ch) + { + return static_cast(ch - 'A') < 26 ? static_cast(ch | ' ') : ch; + } + + PUGI__FN xpath_string string_value(const xpath_node& na, xpath_allocator* alloc) + { + if (na.attribute()) + return xpath_string::from_const(na.attribute().value()); + else + { + xml_node n = na.node(); + + switch (n.type()) + { + case node_pcdata: + case node_cdata: + case node_comment: + case node_pi: + return xpath_string::from_const(n.value()); + + case node_document: + case node_element: + { + xpath_string result; + + // element nodes can have value if parse_embed_pcdata was used + if (n.value()[0]) + result.append(xpath_string::from_const(n.value()), alloc); + + xml_node cur = n.first_child(); + + while (cur && cur != n) + { + if (cur.type() == node_pcdata || cur.type() == node_cdata) + result.append(xpath_string::from_const(cur.value()), alloc); + + if (cur.first_child()) + cur = cur.first_child(); + else if (cur.next_sibling()) + cur = cur.next_sibling(); + else + { + while (!cur.next_sibling() && cur != n) + cur = cur.parent(); + + if (cur != n) cur = cur.next_sibling(); + } + } + + return result; + } + + default: + return xpath_string(); + } + } + } + + PUGI__FN bool node_is_before_sibling(xml_node_struct* ln, xml_node_struct* rn) + { + assert(ln->parent == rn->parent); + + // there is no common ancestor (the shared parent is null), nodes are from different documents + if (!ln->parent) return ln < rn; + + // determine sibling order + xml_node_struct* ls = ln; + xml_node_struct* rs = rn; + + while (ls && rs) + { + if (ls == rn) return true; + if (rs == ln) return false; + + ls = ls->next_sibling; + rs = rs->next_sibling; + } + + // if rn sibling chain ended ln must be before rn + return !rs; + } + + PUGI__FN bool node_is_before(xml_node_struct* ln, xml_node_struct* rn) + { + // find common ancestor at the same depth, if any + xml_node_struct* lp = ln; + xml_node_struct* rp = rn; + + while (lp && rp && lp->parent != rp->parent) + { + lp = lp->parent; + rp = rp->parent; + } + + // parents are the same! + if (lp && rp) return node_is_before_sibling(lp, rp); + + // nodes are at different depths, need to normalize heights + bool left_higher = !lp; + + while (lp) + { + lp = lp->parent; + ln = ln->parent; + } + + while (rp) + { + rp = rp->parent; + rn = rn->parent; + } + + // one node is the ancestor of the other + if (ln == rn) return left_higher; + + // find common ancestor... again + while (ln->parent != rn->parent) + { + ln = ln->parent; + rn = rn->parent; + } + + return node_is_before_sibling(ln, rn); + } + + PUGI__FN bool node_is_ancestor(xml_node_struct* parent, xml_node_struct* node) + { + while (node && node != parent) node = node->parent; + + return parent && node == parent; + } + + PUGI__FN const void* document_buffer_order(const xpath_node& xnode) + { + xml_node_struct* node = xnode.node().internal_object(); + + if (node) + { + if ((get_document(node).header & xml_memory_page_contents_shared_mask) == 0) + { + if (node->name && (node->header & impl::xml_memory_page_name_allocated_or_shared_mask) == 0) return node->name; + if (node->value && (node->header & impl::xml_memory_page_value_allocated_or_shared_mask) == 0) return node->value; + } + + return 0; + } + + xml_attribute_struct* attr = xnode.attribute().internal_object(); + + if (attr) + { + if ((get_document(attr).header & xml_memory_page_contents_shared_mask) == 0) + { + if ((attr->header & impl::xml_memory_page_name_allocated_or_shared_mask) == 0) return attr->name; + if ((attr->header & impl::xml_memory_page_value_allocated_or_shared_mask) == 0) return attr->value; + } + + return 0; + } + + return 0; + } + + struct document_order_comparator + { + bool operator()(const xpath_node& lhs, const xpath_node& rhs) const + { + // optimized document order based check + const void* lo = document_buffer_order(lhs); + const void* ro = document_buffer_order(rhs); + + if (lo && ro) return lo < ro; + + // slow comparison + xml_node ln = lhs.node(), rn = rhs.node(); + + // compare attributes + if (lhs.attribute() && rhs.attribute()) + { + // shared parent + if (lhs.parent() == rhs.parent()) + { + // determine sibling order + for (xml_attribute a = lhs.attribute(); a; a = a.next_attribute()) + if (a == rhs.attribute()) + return true; + + return false; + } + + // compare attribute parents + ln = lhs.parent(); + rn = rhs.parent(); + } + else if (lhs.attribute()) + { + // attributes go after the parent element + if (lhs.parent() == rhs.node()) return false; + + ln = lhs.parent(); + } + else if (rhs.attribute()) + { + // attributes go after the parent element + if (rhs.parent() == lhs.node()) return true; + + rn = rhs.parent(); + } + + if (ln == rn) return false; + + if (!ln || !rn) return ln < rn; + + return node_is_before(ln.internal_object(), rn.internal_object()); + } + }; + + PUGI__FN double gen_nan() + { + #if defined(__STDC_IEC_559__) || ((FLT_RADIX - 0 == 2) && (FLT_MAX_EXP - 0 == 128) && (FLT_MANT_DIG - 0 == 24)) + PUGI__STATIC_ASSERT(sizeof(float) == sizeof(uint32_t)); + typedef uint32_t UI; // BCC5 workaround + union { float f; UI i; } u; + u.i = 0x7fc00000; + return double(u.f); + #else + // fallback + const volatile double zero = 0.0; + return zero / zero; + #endif + } + + PUGI__FN bool is_nan(double value) + { + #if defined(PUGI__MSVC_CRT_VERSION) || defined(__BORLANDC__) + return !!_isnan(value); + #elif defined(fpclassify) && defined(FP_NAN) + return fpclassify(value) == FP_NAN; + #else + // fallback + const volatile double v = value; + return v != v; + #endif + } + + PUGI__FN const char_t* convert_number_to_string_special(double value) + { + #if defined(PUGI__MSVC_CRT_VERSION) || defined(__BORLANDC__) + if (_finite(value)) return (value == 0) ? PUGIXML_TEXT("0") : 0; + if (_isnan(value)) return PUGIXML_TEXT("NaN"); + return value > 0 ? PUGIXML_TEXT("Infinity") : PUGIXML_TEXT("-Infinity"); + #elif defined(fpclassify) && defined(FP_NAN) && defined(FP_INFINITE) && defined(FP_ZERO) + switch (fpclassify(value)) + { + case FP_NAN: + return PUGIXML_TEXT("NaN"); + + case FP_INFINITE: + return value > 0 ? PUGIXML_TEXT("Infinity") : PUGIXML_TEXT("-Infinity"); + + case FP_ZERO: + return PUGIXML_TEXT("0"); + + default: + return 0; + } + #else + // fallback + const volatile double v = value; + + if (v == 0) return PUGIXML_TEXT("0"); + if (v != v) return PUGIXML_TEXT("NaN"); + if (v * 2 == v) return value > 0 ? PUGIXML_TEXT("Infinity") : PUGIXML_TEXT("-Infinity"); + return 0; + #endif + } + + PUGI__FN bool convert_number_to_boolean(double value) + { + return (value != 0 && !is_nan(value)); + } + + PUGI__FN void truncate_zeros(char* begin, char* end) + { + while (begin != end && end[-1] == '0') end--; + + *end = 0; + } + + // gets mantissa digits in the form of 0.xxxxx with 0. implied and the exponent +#if defined(PUGI__MSVC_CRT_VERSION) && PUGI__MSVC_CRT_VERSION >= 1400 && !defined(_WIN32_WCE) + PUGI__FN void convert_number_to_mantissa_exponent(double value, char (&buffer)[32], char** out_mantissa, int* out_exponent) + { + // get base values + int sign, exponent; + _ecvt_s(buffer, sizeof(buffer), value, DBL_DIG + 1, &exponent, &sign); + + // truncate redundant zeros + truncate_zeros(buffer, buffer + strlen(buffer)); + + // fill results + *out_mantissa = buffer; + *out_exponent = exponent; + } +#else + PUGI__FN void convert_number_to_mantissa_exponent(double value, char (&buffer)[32], char** out_mantissa, int* out_exponent) + { + // get a scientific notation value with IEEE DBL_DIG decimals + PUGI__SNPRINTF(buffer, "%.*e", DBL_DIG, value); + + // get the exponent (possibly negative) + char* exponent_string = strchr(buffer, 'e'); + assert(exponent_string); + + int exponent = atoi(exponent_string + 1); + + // extract mantissa string: skip sign + char* mantissa = buffer[0] == '-' ? buffer + 1 : buffer; + assert(mantissa[0] != '0' && mantissa[1] == '.'); + + // divide mantissa by 10 to eliminate integer part + mantissa[1] = mantissa[0]; + mantissa++; + exponent++; + + // remove extra mantissa digits and zero-terminate mantissa + truncate_zeros(mantissa, exponent_string); + + // fill results + *out_mantissa = mantissa; + *out_exponent = exponent; + } +#endif + + PUGI__FN xpath_string convert_number_to_string(double value, xpath_allocator* alloc) + { + // try special number conversion + const char_t* special = convert_number_to_string_special(value); + if (special) return xpath_string::from_const(special); + + // get mantissa + exponent form + char mantissa_buffer[32]; + + char* mantissa; + int exponent; + convert_number_to_mantissa_exponent(value, mantissa_buffer, &mantissa, &exponent); + + // allocate a buffer of suitable length for the number + size_t result_size = strlen(mantissa_buffer) + (exponent > 0 ? exponent : -exponent) + 4; + char_t* result = static_cast(alloc->allocate(sizeof(char_t) * result_size)); + if (!result) return xpath_string(); + + // make the number! + char_t* s = result; + + // sign + if (value < 0) *s++ = '-'; + + // integer part + if (exponent <= 0) + { + *s++ = '0'; + } + else + { + while (exponent > 0) + { + assert(*mantissa == 0 || static_cast(*mantissa - '0') <= 9); + *s++ = *mantissa ? *mantissa++ : '0'; + exponent--; + } + } + + // fractional part + if (*mantissa) + { + // decimal point + *s++ = '.'; + + // extra zeroes from negative exponent + while (exponent < 0) + { + *s++ = '0'; + exponent++; + } + + // extra mantissa digits + while (*mantissa) + { + assert(static_cast(*mantissa - '0') <= 9); + *s++ = *mantissa++; + } + } + + // zero-terminate + assert(s < result + result_size); + *s = 0; + + return xpath_string::from_heap_preallocated(result, s); + } + + PUGI__FN bool check_string_to_number_format(const char_t* string) + { + // parse leading whitespace + while (PUGI__IS_CHARTYPE(*string, ct_space)) ++string; + + // parse sign + if (*string == '-') ++string; + + if (!*string) return false; + + // if there is no integer part, there should be a decimal part with at least one digit + if (!PUGI__IS_CHARTYPEX(string[0], ctx_digit) && (string[0] != '.' || !PUGI__IS_CHARTYPEX(string[1], ctx_digit))) return false; + + // parse integer part + while (PUGI__IS_CHARTYPEX(*string, ctx_digit)) ++string; + + // parse decimal part + if (*string == '.') + { + ++string; + + while (PUGI__IS_CHARTYPEX(*string, ctx_digit)) ++string; + } + + // parse trailing whitespace + while (PUGI__IS_CHARTYPE(*string, ct_space)) ++string; + + return *string == 0; + } + + PUGI__FN double convert_string_to_number(const char_t* string) + { + // check string format + if (!check_string_to_number_format(string)) return gen_nan(); + + // parse string + #ifdef PUGIXML_WCHAR_MODE + return wcstod(string, 0); + #else + return strtod(string, 0); + #endif + } + + PUGI__FN bool convert_string_to_number_scratch(char_t (&buffer)[32], const char_t* begin, const char_t* end, double* out_result) + { + size_t length = static_cast(end - begin); + char_t* scratch = buffer; + + if (length >= sizeof(buffer) / sizeof(buffer[0])) + { + // need to make dummy on-heap copy + scratch = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!scratch) return false; + } + + // copy string to zero-terminated buffer and perform conversion + memcpy(scratch, begin, length * sizeof(char_t)); + scratch[length] = 0; + + *out_result = convert_string_to_number(scratch); + + // free dummy buffer + if (scratch != buffer) xml_memory::deallocate(scratch); + + return true; + } + + PUGI__FN double round_nearest(double value) + { + return floor(value + 0.5); + } + + PUGI__FN double round_nearest_nzero(double value) + { + // same as round_nearest, but returns -0 for [-0.5, -0] + // ceil is used to differentiate between +0 and -0 (we return -0 for [-0.5, -0] and +0 for +0) + return (value >= -0.5 && value <= 0) ? ceil(value) : floor(value + 0.5); + } + + PUGI__FN const char_t* qualified_name(const xpath_node& node) + { + return node.attribute() ? node.attribute().name() : node.node().name(); + } + + PUGI__FN const char_t* local_name(const xpath_node& node) + { + const char_t* name = qualified_name(node); + const char_t* p = find_char(name, ':'); + + return p ? p + 1 : name; + } + + struct namespace_uri_predicate + { + const char_t* prefix; + size_t prefix_length; + + namespace_uri_predicate(const char_t* name) + { + const char_t* pos = find_char(name, ':'); + + prefix = pos ? name : 0; + prefix_length = pos ? static_cast(pos - name) : 0; + } + + bool operator()(xml_attribute a) const + { + const char_t* name = a.name(); + + if (!starts_with(name, PUGIXML_TEXT("xmlns"))) return false; + + return prefix ? name[5] == ':' && strequalrange(name + 6, prefix, prefix_length) : name[5] == 0; + } + }; + + PUGI__FN const char_t* namespace_uri(xml_node node) + { + namespace_uri_predicate pred = node.name(); + + xml_node p = node; + + while (p) + { + xml_attribute a = p.find_attribute(pred); + + if (a) return a.value(); + + p = p.parent(); + } + + return PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* namespace_uri(xml_attribute attr, xml_node parent) + { + namespace_uri_predicate pred = attr.name(); + + // Default namespace does not apply to attributes + if (!pred.prefix) return PUGIXML_TEXT(""); + + xml_node p = parent; + + while (p) + { + xml_attribute a = p.find_attribute(pred); + + if (a) return a.value(); + + p = p.parent(); + } + + return PUGIXML_TEXT(""); + } + + PUGI__FN const char_t* namespace_uri(const xpath_node& node) + { + return node.attribute() ? namespace_uri(node.attribute(), node.parent()) : namespace_uri(node.node()); + } + + PUGI__FN char_t* normalize_space(char_t* buffer) + { + char_t* write = buffer; + + for (char_t* it = buffer; *it; ) + { + char_t ch = *it++; + + if (PUGI__IS_CHARTYPE(ch, ct_space)) + { + // replace whitespace sequence with single space + while (PUGI__IS_CHARTYPE(*it, ct_space)) it++; + + // avoid leading spaces + if (write != buffer) *write++ = ' '; + } + else *write++ = ch; + } + + // remove trailing space + if (write != buffer && PUGI__IS_CHARTYPE(write[-1], ct_space)) write--; + + // zero-terminate + *write = 0; + + return write; + } + + PUGI__FN char_t* translate(char_t* buffer, const char_t* from, const char_t* to, size_t to_length) + { + char_t* write = buffer; + + while (*buffer) + { + PUGI__DMC_VOLATILE char_t ch = *buffer++; + + const char_t* pos = find_char(from, ch); + + if (!pos) + *write++ = ch; // do not process + else if (static_cast(pos - from) < to_length) + *write++ = to[pos - from]; // replace + } + + // zero-terminate + *write = 0; + + return write; + } + + PUGI__FN unsigned char* translate_table_generate(xpath_allocator* alloc, const char_t* from, const char_t* to) + { + unsigned char table[128] = {0}; + + while (*from) + { + unsigned int fc = static_cast(*from); + unsigned int tc = static_cast(*to); + + if (fc >= 128 || tc >= 128) + return 0; + + // code=128 means "skip character" + if (!table[fc]) + table[fc] = static_cast(tc ? tc : 128); + + from++; + if (tc) to++; + } + + for (int i = 0; i < 128; ++i) + if (!table[i]) + table[i] = static_cast(i); + + void* result = alloc->allocate(sizeof(table)); + if (!result) return 0; + + memcpy(result, table, sizeof(table)); + + return static_cast(result); + } + + PUGI__FN char_t* translate_table(char_t* buffer, const unsigned char* table) + { + char_t* write = buffer; + + while (*buffer) + { + char_t ch = *buffer++; + unsigned int index = static_cast(ch); + + if (index < 128) + { + unsigned char code = table[index]; + + // code=128 means "skip character" (table size is 128 so 128 can be a special value) + // this code skips these characters without extra branches + *write = static_cast(code); + write += 1 - (code >> 7); + } + else + { + *write++ = ch; + } + } + + // zero-terminate + *write = 0; + + return write; + } + + inline bool is_xpath_attribute(const char_t* name) + { + return !(starts_with(name, PUGIXML_TEXT("xmlns")) && (name[5] == 0 || name[5] == ':')); + } + + struct xpath_variable_boolean: xpath_variable + { + xpath_variable_boolean(): xpath_variable(xpath_type_boolean), value(false) + { + } + + bool value; + char_t name[1]; + }; + + struct xpath_variable_number: xpath_variable + { + xpath_variable_number(): xpath_variable(xpath_type_number), value(0) + { + } + + double value; + char_t name[1]; + }; + + struct xpath_variable_string: xpath_variable + { + xpath_variable_string(): xpath_variable(xpath_type_string), value(0) + { + } + + ~xpath_variable_string() + { + if (value) xml_memory::deallocate(value); + } + + char_t* value; + char_t name[1]; + }; + + struct xpath_variable_node_set: xpath_variable + { + xpath_variable_node_set(): xpath_variable(xpath_type_node_set) + { + } + + xpath_node_set value; + char_t name[1]; + }; + + static const xpath_node_set dummy_node_set; + + PUGI__FN PUGI__UNSIGNED_OVERFLOW unsigned int hash_string(const char_t* str) + { + // Jenkins one-at-a-time hash (http://en.wikipedia.org/wiki/Jenkins_hash_function#one-at-a-time) + unsigned int result = 0; + + while (*str) + { + result += static_cast(*str++); + result += result << 10; + result ^= result >> 6; + } + + result += result << 3; + result ^= result >> 11; + result += result << 15; + + return result; + } + + template PUGI__FN T* new_xpath_variable(const char_t* name) + { + size_t length = strlength(name); + if (length == 0) return 0; // empty variable names are invalid + + // $$ we can't use offsetof(T, name) because T is non-POD, so we just allocate additional length characters + void* memory = xml_memory::allocate(sizeof(T) + length * sizeof(char_t)); + if (!memory) return 0; + + T* result = new (memory) T(); + + memcpy(result->name, name, (length + 1) * sizeof(char_t)); + + return result; + } + + PUGI__FN xpath_variable* new_xpath_variable(xpath_value_type type, const char_t* name) + { + switch (type) + { + case xpath_type_node_set: + return new_xpath_variable(name); + + case xpath_type_number: + return new_xpath_variable(name); + + case xpath_type_string: + return new_xpath_variable(name); + + case xpath_type_boolean: + return new_xpath_variable(name); + + default: + return 0; + } + } + + template PUGI__FN void delete_xpath_variable(T* var) + { + var->~T(); + xml_memory::deallocate(var); + } + + PUGI__FN void delete_xpath_variable(xpath_value_type type, xpath_variable* var) + { + switch (type) + { + case xpath_type_node_set: + delete_xpath_variable(static_cast(var)); + break; + + case xpath_type_number: + delete_xpath_variable(static_cast(var)); + break; + + case xpath_type_string: + delete_xpath_variable(static_cast(var)); + break; + + case xpath_type_boolean: + delete_xpath_variable(static_cast(var)); + break; + + default: + assert(false && "Invalid variable type"); // unreachable + } + } + + PUGI__FN bool copy_xpath_variable(xpath_variable* lhs, const xpath_variable* rhs) + { + switch (rhs->type()) + { + case xpath_type_node_set: + return lhs->set(static_cast(rhs)->value); + + case xpath_type_number: + return lhs->set(static_cast(rhs)->value); + + case xpath_type_string: + return lhs->set(static_cast(rhs)->value); + + case xpath_type_boolean: + return lhs->set(static_cast(rhs)->value); + + default: + assert(false && "Invalid variable type"); // unreachable + return false; + } + } + + PUGI__FN bool get_variable_scratch(char_t (&buffer)[32], xpath_variable_set* set, const char_t* begin, const char_t* end, xpath_variable** out_result) + { + size_t length = static_cast(end - begin); + char_t* scratch = buffer; + + if (length >= sizeof(buffer) / sizeof(buffer[0])) + { + // need to make dummy on-heap copy + scratch = static_cast(xml_memory::allocate((length + 1) * sizeof(char_t))); + if (!scratch) return false; + } + + // copy string to zero-terminated buffer and perform lookup + memcpy(scratch, begin, length * sizeof(char_t)); + scratch[length] = 0; + + *out_result = set->get(scratch); + + // free dummy buffer + if (scratch != buffer) xml_memory::deallocate(scratch); + + return true; + } +PUGI__NS_END + +// Internal node set class +PUGI__NS_BEGIN + PUGI__FN xpath_node_set::type_t xpath_get_order(const xpath_node* begin, const xpath_node* end) + { + if (end - begin < 2) + return xpath_node_set::type_sorted; + + document_order_comparator cmp; + + bool first = cmp(begin[0], begin[1]); + + for (const xpath_node* it = begin + 1; it + 1 < end; ++it) + if (cmp(it[0], it[1]) != first) + return xpath_node_set::type_unsorted; + + return first ? xpath_node_set::type_sorted : xpath_node_set::type_sorted_reverse; + } + + PUGI__FN xpath_node_set::type_t xpath_sort(xpath_node* begin, xpath_node* end, xpath_node_set::type_t type, bool rev) + { + xpath_node_set::type_t order = rev ? xpath_node_set::type_sorted_reverse : xpath_node_set::type_sorted; + + if (type == xpath_node_set::type_unsorted) + { + xpath_node_set::type_t sorted = xpath_get_order(begin, end); + + if (sorted == xpath_node_set::type_unsorted) + { + sort(begin, end, document_order_comparator()); + + type = xpath_node_set::type_sorted; + } + else + type = sorted; + } + + if (type != order) reverse(begin, end); + + return order; + } + + PUGI__FN xpath_node xpath_first(const xpath_node* begin, const xpath_node* end, xpath_node_set::type_t type) + { + if (begin == end) return xpath_node(); + + switch (type) + { + case xpath_node_set::type_sorted: + return *begin; + + case xpath_node_set::type_sorted_reverse: + return *(end - 1); + + case xpath_node_set::type_unsorted: + return *min_element(begin, end, document_order_comparator()); + + default: + assert(false && "Invalid node set type"); // unreachable + return xpath_node(); + } + } + + class xpath_node_set_raw + { + xpath_node_set::type_t _type; + + xpath_node* _begin; + xpath_node* _end; + xpath_node* _eos; + + public: + xpath_node_set_raw(): _type(xpath_node_set::type_unsorted), _begin(0), _end(0), _eos(0) + { + } + + xpath_node* begin() const + { + return _begin; + } + + xpath_node* end() const + { + return _end; + } + + bool empty() const + { + return _begin == _end; + } + + size_t size() const + { + return static_cast(_end - _begin); + } + + xpath_node first() const + { + return xpath_first(_begin, _end, _type); + } + + void push_back_grow(const xpath_node& node, xpath_allocator* alloc); + + void push_back(const xpath_node& node, xpath_allocator* alloc) + { + if (_end != _eos) + *_end++ = node; + else + push_back_grow(node, alloc); + } + + void append(const xpath_node* begin_, const xpath_node* end_, xpath_allocator* alloc) + { + if (begin_ == end_) return; + + size_t size_ = static_cast(_end - _begin); + size_t capacity = static_cast(_eos - _begin); + size_t count = static_cast(end_ - begin_); + + if (size_ + count > capacity) + { + // reallocate the old array or allocate a new one + xpath_node* data = static_cast(alloc->reallocate(_begin, capacity * sizeof(xpath_node), (size_ + count) * sizeof(xpath_node))); + if (!data) return; + + // finalize + _begin = data; + _end = data + size_; + _eos = data + size_ + count; + } + + memcpy(_end, begin_, count * sizeof(xpath_node)); + _end += count; + } + + void sort_do() + { + _type = xpath_sort(_begin, _end, _type, false); + } + + void truncate(xpath_node* pos) + { + assert(_begin <= pos && pos <= _end); + + _end = pos; + } + + void remove_duplicates(xpath_allocator* alloc) + { + if (_type == xpath_node_set::type_unsorted && _end - _begin > 2) + { + xpath_allocator_capture cr(alloc); + + size_t size_ = static_cast(_end - _begin); + + size_t hash_size = 1; + while (hash_size < size_ + size_ / 2) hash_size *= 2; + + const void** hash_data = static_cast(alloc->allocate(hash_size * sizeof(void**))); + if (!hash_data) return; + + memset(hash_data, 0, hash_size * sizeof(const void**)); + + xpath_node* write = _begin; + + for (xpath_node* it = _begin; it != _end; ++it) + { + const void* attr = it->attribute().internal_object(); + const void* node = it->node().internal_object(); + const void* key = attr ? attr : node; + + if (key && hash_insert(hash_data, hash_size, key)) + { + *write++ = *it; + } + } + + _end = write; + } + else + { + _end = unique(_begin, _end); + } + } + + xpath_node_set::type_t type() const + { + return _type; + } + + void set_type(xpath_node_set::type_t value) + { + _type = value; + } + }; + + PUGI__FN_NO_INLINE void xpath_node_set_raw::push_back_grow(const xpath_node& node, xpath_allocator* alloc) + { + size_t capacity = static_cast(_eos - _begin); + + // get new capacity (1.5x rule) + size_t new_capacity = capacity + capacity / 2 + 1; + + // reallocate the old array or allocate a new one + xpath_node* data = static_cast(alloc->reallocate(_begin, capacity * sizeof(xpath_node), new_capacity * sizeof(xpath_node))); + if (!data) return; + + // finalize + _begin = data; + _end = data + capacity; + _eos = data + new_capacity; + + // push + *_end++ = node; + } +PUGI__NS_END + +PUGI__NS_BEGIN + struct xpath_context + { + xpath_node n; + size_t position, size; + + xpath_context(const xpath_node& n_, size_t position_, size_t size_): n(n_), position(position_), size(size_) + { + } + }; + + enum lexeme_t + { + lex_none = 0, + lex_equal, + lex_not_equal, + lex_less, + lex_greater, + lex_less_or_equal, + lex_greater_or_equal, + lex_plus, + lex_minus, + lex_multiply, + lex_union, + lex_var_ref, + lex_open_brace, + lex_close_brace, + lex_quoted_string, + lex_number, + lex_slash, + lex_double_slash, + lex_open_square_brace, + lex_close_square_brace, + lex_string, + lex_comma, + lex_axis_attribute, + lex_dot, + lex_double_dot, + lex_double_colon, + lex_eof + }; + + struct xpath_lexer_string + { + const char_t* begin; + const char_t* end; + + xpath_lexer_string(): begin(0), end(0) + { + } + + bool operator==(const char_t* other) const + { + size_t length = static_cast(end - begin); + + return strequalrange(other, begin, length); + } + }; + + class xpath_lexer + { + const char_t* _cur; + const char_t* _cur_lexeme_pos; + xpath_lexer_string _cur_lexeme_contents; + + lexeme_t _cur_lexeme; + + public: + explicit xpath_lexer(const char_t* query): _cur(query) + { + next(); + } + + const char_t* state() const + { + return _cur; + } + + void next() + { + const char_t* cur = _cur; + + while (PUGI__IS_CHARTYPE(*cur, ct_space)) ++cur; + + // save lexeme position for error reporting + _cur_lexeme_pos = cur; + + switch (*cur) + { + case 0: + _cur_lexeme = lex_eof; + break; + + case '>': + if (*(cur+1) == '=') + { + cur += 2; + _cur_lexeme = lex_greater_or_equal; + } + else + { + cur += 1; + _cur_lexeme = lex_greater; + } + break; + + case '<': + if (*(cur+1) == '=') + { + cur += 2; + _cur_lexeme = lex_less_or_equal; + } + else + { + cur += 1; + _cur_lexeme = lex_less; + } + break; + + case '!': + if (*(cur+1) == '=') + { + cur += 2; + _cur_lexeme = lex_not_equal; + } + else + { + _cur_lexeme = lex_none; + } + break; + + case '=': + cur += 1; + _cur_lexeme = lex_equal; + + break; + + case '+': + cur += 1; + _cur_lexeme = lex_plus; + + break; + + case '-': + cur += 1; + _cur_lexeme = lex_minus; + + break; + + case '*': + cur += 1; + _cur_lexeme = lex_multiply; + + break; + + case '|': + cur += 1; + _cur_lexeme = lex_union; + + break; + + case '$': + cur += 1; + + if (PUGI__IS_CHARTYPEX(*cur, ctx_start_symbol)) + { + _cur_lexeme_contents.begin = cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + + if (cur[0] == ':' && PUGI__IS_CHARTYPEX(cur[1], ctx_symbol)) // qname + { + cur++; // : + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + } + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_var_ref; + } + else + { + _cur_lexeme = lex_none; + } + + break; + + case '(': + cur += 1; + _cur_lexeme = lex_open_brace; + + break; + + case ')': + cur += 1; + _cur_lexeme = lex_close_brace; + + break; + + case '[': + cur += 1; + _cur_lexeme = lex_open_square_brace; + + break; + + case ']': + cur += 1; + _cur_lexeme = lex_close_square_brace; + + break; + + case ',': + cur += 1; + _cur_lexeme = lex_comma; + + break; + + case '/': + if (*(cur+1) == '/') + { + cur += 2; + _cur_lexeme = lex_double_slash; + } + else + { + cur += 1; + _cur_lexeme = lex_slash; + } + break; + + case '.': + if (*(cur+1) == '.') + { + cur += 2; + _cur_lexeme = lex_double_dot; + } + else if (PUGI__IS_CHARTYPEX(*(cur+1), ctx_digit)) + { + _cur_lexeme_contents.begin = cur; // . + + ++cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) cur++; + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_number; + } + else + { + cur += 1; + _cur_lexeme = lex_dot; + } + break; + + case '@': + cur += 1; + _cur_lexeme = lex_axis_attribute; + + break; + + case '"': + case '\'': + { + char_t terminator = *cur; + + ++cur; + + _cur_lexeme_contents.begin = cur; + while (*cur && *cur != terminator) cur++; + _cur_lexeme_contents.end = cur; + + if (!*cur) + _cur_lexeme = lex_none; + else + { + cur += 1; + _cur_lexeme = lex_quoted_string; + } + + break; + } + + case ':': + if (*(cur+1) == ':') + { + cur += 2; + _cur_lexeme = lex_double_colon; + } + else + { + _cur_lexeme = lex_none; + } + break; + + default: + if (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) + { + _cur_lexeme_contents.begin = cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) cur++; + + if (*cur == '.') + { + cur++; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_digit)) cur++; + } + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_number; + } + else if (PUGI__IS_CHARTYPEX(*cur, ctx_start_symbol)) + { + _cur_lexeme_contents.begin = cur; + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + + if (cur[0] == ':') + { + if (cur[1] == '*') // namespace test ncname:* + { + cur += 2; // :* + } + else if (PUGI__IS_CHARTYPEX(cur[1], ctx_symbol)) // namespace test qname + { + cur++; // : + + while (PUGI__IS_CHARTYPEX(*cur, ctx_symbol)) cur++; + } + } + + _cur_lexeme_contents.end = cur; + + _cur_lexeme = lex_string; + } + else + { + _cur_lexeme = lex_none; + } + } + + _cur = cur; + } + + lexeme_t current() const + { + return _cur_lexeme; + } + + const char_t* current_pos() const + { + return _cur_lexeme_pos; + } + + const xpath_lexer_string& contents() const + { + assert(_cur_lexeme == lex_var_ref || _cur_lexeme == lex_number || _cur_lexeme == lex_string || _cur_lexeme == lex_quoted_string); + + return _cur_lexeme_contents; + } + }; + + enum ast_type_t + { + ast_unknown, + ast_op_or, // left or right + ast_op_and, // left and right + ast_op_equal, // left = right + ast_op_not_equal, // left != right + ast_op_less, // left < right + ast_op_greater, // left > right + ast_op_less_or_equal, // left <= right + ast_op_greater_or_equal, // left >= right + ast_op_add, // left + right + ast_op_subtract, // left - right + ast_op_multiply, // left * right + ast_op_divide, // left / right + ast_op_mod, // left % right + ast_op_negate, // left - right + ast_op_union, // left | right + ast_predicate, // apply predicate to set; next points to next predicate + ast_filter, // select * from left where right + ast_string_constant, // string constant + ast_number_constant, // number constant + ast_variable, // variable + ast_func_last, // last() + ast_func_position, // position() + ast_func_count, // count(left) + ast_func_id, // id(left) + ast_func_local_name_0, // local-name() + ast_func_local_name_1, // local-name(left) + ast_func_namespace_uri_0, // namespace-uri() + ast_func_namespace_uri_1, // namespace-uri(left) + ast_func_name_0, // name() + ast_func_name_1, // name(left) + ast_func_string_0, // string() + ast_func_string_1, // string(left) + ast_func_concat, // concat(left, right, siblings) + ast_func_starts_with, // starts_with(left, right) + ast_func_contains, // contains(left, right) + ast_func_substring_before, // substring-before(left, right) + ast_func_substring_after, // substring-after(left, right) + ast_func_substring_2, // substring(left, right) + ast_func_substring_3, // substring(left, right, third) + ast_func_string_length_0, // string-length() + ast_func_string_length_1, // string-length(left) + ast_func_normalize_space_0, // normalize-space() + ast_func_normalize_space_1, // normalize-space(left) + ast_func_translate, // translate(left, right, third) + ast_func_boolean, // boolean(left) + ast_func_not, // not(left) + ast_func_true, // true() + ast_func_false, // false() + ast_func_lang, // lang(left) + ast_func_number_0, // number() + ast_func_number_1, // number(left) + ast_func_sum, // sum(left) + ast_func_floor, // floor(left) + ast_func_ceiling, // ceiling(left) + ast_func_round, // round(left) + ast_step, // process set left with step + ast_step_root, // select root node + + ast_opt_translate_table, // translate(left, right, third) where right/third are constants + ast_opt_compare_attribute // @name = 'string' + }; + + enum axis_t + { + axis_ancestor, + axis_ancestor_or_self, + axis_attribute, + axis_child, + axis_descendant, + axis_descendant_or_self, + axis_following, + axis_following_sibling, + axis_namespace, + axis_parent, + axis_preceding, + axis_preceding_sibling, + axis_self + }; + + enum nodetest_t + { + nodetest_none, + nodetest_name, + nodetest_type_node, + nodetest_type_comment, + nodetest_type_pi, + nodetest_type_text, + nodetest_pi, + nodetest_all, + nodetest_all_in_namespace + }; + + enum predicate_t + { + predicate_default, + predicate_posinv, + predicate_constant, + predicate_constant_one + }; + + enum nodeset_eval_t + { + nodeset_eval_all, + nodeset_eval_any, + nodeset_eval_first + }; + + template struct axis_to_type + { + static const axis_t axis; + }; + + template const axis_t axis_to_type::axis = N; + + class xpath_ast_node + { + private: + // node type + char _type; + char _rettype; + + // for ast_step + char _axis; + + // for ast_step/ast_predicate/ast_filter + char _test; + + // tree node structure + xpath_ast_node* _left; + xpath_ast_node* _right; + xpath_ast_node* _next; + + union + { + // value for ast_string_constant + const char_t* string; + // value for ast_number_constant + double number; + // variable for ast_variable + xpath_variable* variable; + // node test for ast_step (node name/namespace/node type/pi target) + const char_t* nodetest; + // table for ast_opt_translate_table + const unsigned char* table; + } _data; + + xpath_ast_node(const xpath_ast_node&); + xpath_ast_node& operator=(const xpath_ast_node&); + + template static bool compare_eq(xpath_ast_node* lhs, xpath_ast_node* rhs, const xpath_context& c, const xpath_stack& stack, const Comp& comp) + { + xpath_value_type lt = lhs->rettype(), rt = rhs->rettype(); + + if (lt != xpath_type_node_set && rt != xpath_type_node_set) + { + if (lt == xpath_type_boolean || rt == xpath_type_boolean) + return comp(lhs->eval_boolean(c, stack), rhs->eval_boolean(c, stack)); + else if (lt == xpath_type_number || rt == xpath_type_number) + return comp(lhs->eval_number(c, stack), rhs->eval_number(c, stack)); + else if (lt == xpath_type_string || rt == xpath_type_string) + { + xpath_allocator_capture cr(stack.result); + + xpath_string ls = lhs->eval_string(c, stack); + xpath_string rs = rhs->eval_string(c, stack); + + return comp(ls, rs); + } + } + else if (lt == xpath_type_node_set && rt == xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ls = lhs->eval_node_set(c, stack, nodeset_eval_all); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack, nodeset_eval_all); + + for (const xpath_node* li = ls.begin(); li != ls.end(); ++li) + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(string_value(*li, stack.result), string_value(*ri, stack.result))) + return true; + } + + return false; + } + else + { + if (lt == xpath_type_node_set) + { + swap(lhs, rhs); + swap(lt, rt); + } + + if (lt == xpath_type_boolean) + return comp(lhs->eval_boolean(c, stack), rhs->eval_boolean(c, stack)); + else if (lt == xpath_type_number) + { + xpath_allocator_capture cr(stack.result); + + double l = lhs->eval_number(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack, nodeset_eval_all); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(l, convert_string_to_number(string_value(*ri, stack.result).c_str()))) + return true; + } + + return false; + } + else if (lt == xpath_type_string) + { + xpath_allocator_capture cr(stack.result); + + xpath_string l = lhs->eval_string(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack, nodeset_eval_all); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(l, string_value(*ri, stack.result))) + return true; + } + + return false; + } + } + + assert(false && "Wrong types"); // unreachable + return false; + } + + static bool eval_once(xpath_node_set::type_t type, nodeset_eval_t eval) + { + return type == xpath_node_set::type_sorted ? eval != nodeset_eval_all : eval == nodeset_eval_any; + } + + template static bool compare_rel(xpath_ast_node* lhs, xpath_ast_node* rhs, const xpath_context& c, const xpath_stack& stack, const Comp& comp) + { + xpath_value_type lt = lhs->rettype(), rt = rhs->rettype(); + + if (lt != xpath_type_node_set && rt != xpath_type_node_set) + return comp(lhs->eval_number(c, stack), rhs->eval_number(c, stack)); + else if (lt == xpath_type_node_set && rt == xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ls = lhs->eval_node_set(c, stack, nodeset_eval_all); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack, nodeset_eval_all); + + for (const xpath_node* li = ls.begin(); li != ls.end(); ++li) + { + xpath_allocator_capture cri(stack.result); + + double l = convert_string_to_number(string_value(*li, stack.result).c_str()); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture crii(stack.result); + + if (comp(l, convert_string_to_number(string_value(*ri, stack.result).c_str()))) + return true; + } + } + + return false; + } + else if (lt != xpath_type_node_set && rt == xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + double l = lhs->eval_number(c, stack); + xpath_node_set_raw rs = rhs->eval_node_set(c, stack, nodeset_eval_all); + + for (const xpath_node* ri = rs.begin(); ri != rs.end(); ++ri) + { + xpath_allocator_capture cri(stack.result); + + if (comp(l, convert_string_to_number(string_value(*ri, stack.result).c_str()))) + return true; + } + + return false; + } + else if (lt == xpath_type_node_set && rt != xpath_type_node_set) + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ls = lhs->eval_node_set(c, stack, nodeset_eval_all); + double r = rhs->eval_number(c, stack); + + for (const xpath_node* li = ls.begin(); li != ls.end(); ++li) + { + xpath_allocator_capture cri(stack.result); + + if (comp(convert_string_to_number(string_value(*li, stack.result).c_str()), r)) + return true; + } + + return false; + } + else + { + assert(false && "Wrong types"); // unreachable + return false; + } + } + + static void apply_predicate_boolean(xpath_node_set_raw& ns, size_t first, xpath_ast_node* expr, const xpath_stack& stack, bool once) + { + assert(ns.size() >= first); + assert(expr->rettype() != xpath_type_number); + + size_t i = 1; + size_t size = ns.size() - first; + + xpath_node* last = ns.begin() + first; + + // remove_if... or well, sort of + for (xpath_node* it = last; it != ns.end(); ++it, ++i) + { + xpath_context c(*it, i, size); + + if (expr->eval_boolean(c, stack)) + { + *last++ = *it; + + if (once) break; + } + } + + ns.truncate(last); + } + + static void apply_predicate_number(xpath_node_set_raw& ns, size_t first, xpath_ast_node* expr, const xpath_stack& stack, bool once) + { + assert(ns.size() >= first); + assert(expr->rettype() == xpath_type_number); + + size_t i = 1; + size_t size = ns.size() - first; + + xpath_node* last = ns.begin() + first; + + // remove_if... or well, sort of + for (xpath_node* it = last; it != ns.end(); ++it, ++i) + { + xpath_context c(*it, i, size); + + if (expr->eval_number(c, stack) == static_cast(i)) + { + *last++ = *it; + + if (once) break; + } + } + + ns.truncate(last); + } + + static void apply_predicate_number_const(xpath_node_set_raw& ns, size_t first, xpath_ast_node* expr, const xpath_stack& stack) + { + assert(ns.size() >= first); + assert(expr->rettype() == xpath_type_number); + + size_t size = ns.size() - first; + + xpath_node* last = ns.begin() + first; + + xpath_context c(xpath_node(), 1, size); + + double er = expr->eval_number(c, stack); + + if (er >= 1.0 && er <= static_cast(size)) + { + size_t eri = static_cast(er); + + if (er == static_cast(eri)) + { + xpath_node r = last[eri - 1]; + + *last++ = r; + } + } + + ns.truncate(last); + } + + void apply_predicate(xpath_node_set_raw& ns, size_t first, const xpath_stack& stack, bool once) + { + if (ns.size() == first) return; + + assert(_type == ast_filter || _type == ast_predicate); + + if (_test == predicate_constant || _test == predicate_constant_one) + apply_predicate_number_const(ns, first, _right, stack); + else if (_right->rettype() == xpath_type_number) + apply_predicate_number(ns, first, _right, stack, once); + else + apply_predicate_boolean(ns, first, _right, stack, once); + } + + void apply_predicates(xpath_node_set_raw& ns, size_t first, const xpath_stack& stack, nodeset_eval_t eval) + { + if (ns.size() == first) return; + + bool last_once = eval_once(ns.type(), eval); + + for (xpath_ast_node* pred = _right; pred; pred = pred->_next) + pred->apply_predicate(ns, first, stack, !pred->_next && last_once); + } + + bool step_push(xpath_node_set_raw& ns, xml_attribute_struct* a, xml_node_struct* parent, xpath_allocator* alloc) + { + assert(a); + + const char_t* name = a->name ? a->name + 0 : PUGIXML_TEXT(""); + + switch (_test) + { + case nodetest_name: + if (strequal(name, _data.nodetest) && is_xpath_attribute(name)) + { + ns.push_back(xpath_node(xml_attribute(a), xml_node(parent)), alloc); + return true; + } + break; + + case nodetest_type_node: + case nodetest_all: + if (is_xpath_attribute(name)) + { + ns.push_back(xpath_node(xml_attribute(a), xml_node(parent)), alloc); + return true; + } + break; + + case nodetest_all_in_namespace: + if (starts_with(name, _data.nodetest) && is_xpath_attribute(name)) + { + ns.push_back(xpath_node(xml_attribute(a), xml_node(parent)), alloc); + return true; + } + break; + + default: + ; + } + + return false; + } + + bool step_push(xpath_node_set_raw& ns, xml_node_struct* n, xpath_allocator* alloc) + { + assert(n); + + xml_node_type type = PUGI__NODETYPE(n); + + switch (_test) + { + case nodetest_name: + if (type == node_element && n->name && strequal(n->name, _data.nodetest)) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + case nodetest_type_node: + ns.push_back(xml_node(n), alloc); + return true; + + case nodetest_type_comment: + if (type == node_comment) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + case nodetest_type_text: + if (type == node_pcdata || type == node_cdata) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + case nodetest_type_pi: + if (type == node_pi) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + case nodetest_pi: + if (type == node_pi && n->name && strequal(n->name, _data.nodetest)) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + case nodetest_all: + if (type == node_element) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + case nodetest_all_in_namespace: + if (type == node_element && n->name && starts_with(n->name, _data.nodetest)) + { + ns.push_back(xml_node(n), alloc); + return true; + } + break; + + default: + assert(false && "Unknown axis"); // unreachable + } + + return false; + } + + template void step_fill(xpath_node_set_raw& ns, xml_node_struct* n, xpath_allocator* alloc, bool once, T) + { + const axis_t axis = T::axis; + + switch (axis) + { + case axis_attribute: + { + for (xml_attribute_struct* a = n->first_attribute; a; a = a->next_attribute) + if (step_push(ns, a, n, alloc) & once) + return; + + break; + } + + case axis_child: + { + for (xml_node_struct* c = n->first_child; c; c = c->next_sibling) + if (step_push(ns, c, alloc) & once) + return; + + break; + } + + case axis_descendant: + case axis_descendant_or_self: + { + if (axis == axis_descendant_or_self) + if (step_push(ns, n, alloc) & once) + return; + + xml_node_struct* cur = n->first_child; + + while (cur) + { + if (step_push(ns, cur, alloc) & once) + return; + + if (cur->first_child) + cur = cur->first_child; + else + { + while (!cur->next_sibling) + { + cur = cur->parent; + + if (cur == n) return; + } + + cur = cur->next_sibling; + } + } + + break; + } + + case axis_following_sibling: + { + for (xml_node_struct* c = n->next_sibling; c; c = c->next_sibling) + if (step_push(ns, c, alloc) & once) + return; + + break; + } + + case axis_preceding_sibling: + { + for (xml_node_struct* c = n->prev_sibling_c; c->next_sibling; c = c->prev_sibling_c) + if (step_push(ns, c, alloc) & once) + return; + + break; + } + + case axis_following: + { + xml_node_struct* cur = n; + + // exit from this node so that we don't include descendants + while (!cur->next_sibling) + { + cur = cur->parent; + + if (!cur) return; + } + + cur = cur->next_sibling; + + while (cur) + { + if (step_push(ns, cur, alloc) & once) + return; + + if (cur->first_child) + cur = cur->first_child; + else + { + while (!cur->next_sibling) + { + cur = cur->parent; + + if (!cur) return; + } + + cur = cur->next_sibling; + } + } + + break; + } + + case axis_preceding: + { + xml_node_struct* cur = n; + + // exit from this node so that we don't include descendants + while (!cur->prev_sibling_c->next_sibling) + { + cur = cur->parent; + + if (!cur) return; + } + + cur = cur->prev_sibling_c; + + while (cur) + { + if (cur->first_child) + cur = cur->first_child->prev_sibling_c; + else + { + // leaf node, can't be ancestor + if (step_push(ns, cur, alloc) & once) + return; + + while (!cur->prev_sibling_c->next_sibling) + { + cur = cur->parent; + + if (!cur) return; + + if (!node_is_ancestor(cur, n)) + if (step_push(ns, cur, alloc) & once) + return; + } + + cur = cur->prev_sibling_c; + } + } + + break; + } + + case axis_ancestor: + case axis_ancestor_or_self: + { + if (axis == axis_ancestor_or_self) + if (step_push(ns, n, alloc) & once) + return; + + xml_node_struct* cur = n->parent; + + while (cur) + { + if (step_push(ns, cur, alloc) & once) + return; + + cur = cur->parent; + } + + break; + } + + case axis_self: + { + step_push(ns, n, alloc); + + break; + } + + case axis_parent: + { + if (n->parent) + step_push(ns, n->parent, alloc); + + break; + } + + default: + assert(false && "Unimplemented axis"); // unreachable + } + } + + template void step_fill(xpath_node_set_raw& ns, xml_attribute_struct* a, xml_node_struct* p, xpath_allocator* alloc, bool once, T v) + { + const axis_t axis = T::axis; + + switch (axis) + { + case axis_ancestor: + case axis_ancestor_or_self: + { + if (axis == axis_ancestor_or_self && _test == nodetest_type_node) // reject attributes based on principal node type test + if (step_push(ns, a, p, alloc) & once) + return; + + xml_node_struct* cur = p; + + while (cur) + { + if (step_push(ns, cur, alloc) & once) + return; + + cur = cur->parent; + } + + break; + } + + case axis_descendant_or_self: + case axis_self: + { + if (_test == nodetest_type_node) // reject attributes based on principal node type test + step_push(ns, a, p, alloc); + + break; + } + + case axis_following: + { + xml_node_struct* cur = p; + + while (cur) + { + if (cur->first_child) + cur = cur->first_child; + else + { + while (!cur->next_sibling) + { + cur = cur->parent; + + if (!cur) return; + } + + cur = cur->next_sibling; + } + + if (step_push(ns, cur, alloc) & once) + return; + } + + break; + } + + case axis_parent: + { + step_push(ns, p, alloc); + + break; + } + + case axis_preceding: + { + // preceding:: axis does not include attribute nodes and attribute ancestors (they are the same as parent's ancestors), so we can reuse node preceding + step_fill(ns, p, alloc, once, v); + break; + } + + default: + assert(false && "Unimplemented axis"); // unreachable + } + } + + template void step_fill(xpath_node_set_raw& ns, const xpath_node& xn, xpath_allocator* alloc, bool once, T v) + { + const axis_t axis = T::axis; + const bool axis_has_attributes = (axis == axis_ancestor || axis == axis_ancestor_or_self || axis == axis_descendant_or_self || axis == axis_following || axis == axis_parent || axis == axis_preceding || axis == axis_self); + + if (xn.node()) + step_fill(ns, xn.node().internal_object(), alloc, once, v); + else if (axis_has_attributes && xn.attribute() && xn.parent()) + step_fill(ns, xn.attribute().internal_object(), xn.parent().internal_object(), alloc, once, v); + } + + template xpath_node_set_raw step_do(const xpath_context& c, const xpath_stack& stack, nodeset_eval_t eval, T v) + { + const axis_t axis = T::axis; + const bool axis_reverse = (axis == axis_ancestor || axis == axis_ancestor_or_self || axis == axis_preceding || axis == axis_preceding_sibling); + const xpath_node_set::type_t axis_type = axis_reverse ? xpath_node_set::type_sorted_reverse : xpath_node_set::type_sorted; + + bool once = + (axis == axis_attribute && _test == nodetest_name) || + (!_right && eval_once(axis_type, eval)) || + // coverity[mixed_enums] + (_right && !_right->_next && _right->_test == predicate_constant_one); + + xpath_node_set_raw ns; + ns.set_type(axis_type); + + if (_left) + { + xpath_node_set_raw s = _left->eval_node_set(c, stack, nodeset_eval_all); + + // self axis preserves the original order + if (axis == axis_self) ns.set_type(s.type()); + + for (const xpath_node* it = s.begin(); it != s.end(); ++it) + { + size_t size = ns.size(); + + // in general, all axes generate elements in a particular order, but there is no order guarantee if axis is applied to two nodes + if (axis != axis_self && size != 0) ns.set_type(xpath_node_set::type_unsorted); + + step_fill(ns, *it, stack.result, once, v); + if (_right) apply_predicates(ns, size, stack, eval); + } + } + else + { + step_fill(ns, c.n, stack.result, once, v); + if (_right) apply_predicates(ns, 0, stack, eval); + } + + // child, attribute and self axes always generate unique set of nodes + // for other axis, if the set stayed sorted, it stayed unique because the traversal algorithms do not visit the same node twice + if (axis != axis_child && axis != axis_attribute && axis != axis_self && ns.type() == xpath_node_set::type_unsorted) + ns.remove_duplicates(stack.temp); + + return ns; + } + + public: + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, const char_t* value): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(0), _right(0), _next(0) + { + assert(type == ast_string_constant); + _data.string = value; + } + + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, double value): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(0), _right(0), _next(0) + { + assert(type == ast_number_constant); + _data.number = value; + } + + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, xpath_variable* value): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(0), _right(0), _next(0) + { + assert(type == ast_variable); + _data.variable = value; + } + + xpath_ast_node(ast_type_t type, xpath_value_type rettype_, xpath_ast_node* left = 0, xpath_ast_node* right = 0): + _type(static_cast(type)), _rettype(static_cast(rettype_)), _axis(0), _test(0), _left(left), _right(right), _next(0) + { + } + + xpath_ast_node(ast_type_t type, xpath_ast_node* left, axis_t axis, nodetest_t test, const char_t* contents): + _type(static_cast(type)), _rettype(xpath_type_node_set), _axis(static_cast(axis)), _test(static_cast(test)), _left(left), _right(0), _next(0) + { + assert(type == ast_step); + _data.nodetest = contents; + } + + xpath_ast_node(ast_type_t type, xpath_ast_node* left, xpath_ast_node* right, predicate_t test): + _type(static_cast(type)), _rettype(xpath_type_node_set), _axis(0), _test(static_cast(test)), _left(left), _right(right), _next(0) + { + assert(type == ast_filter || type == ast_predicate); + } + + void set_next(xpath_ast_node* value) + { + _next = value; + } + + void set_right(xpath_ast_node* value) + { + _right = value; + } + + bool eval_boolean(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_op_or: + return _left->eval_boolean(c, stack) || _right->eval_boolean(c, stack); + + case ast_op_and: + return _left->eval_boolean(c, stack) && _right->eval_boolean(c, stack); + + case ast_op_equal: + return compare_eq(_left, _right, c, stack, equal_to()); + + case ast_op_not_equal: + return compare_eq(_left, _right, c, stack, not_equal_to()); + + case ast_op_less: + return compare_rel(_left, _right, c, stack, less()); + + case ast_op_greater: + return compare_rel(_right, _left, c, stack, less()); + + case ast_op_less_or_equal: + return compare_rel(_left, _right, c, stack, less_equal()); + + case ast_op_greater_or_equal: + return compare_rel(_right, _left, c, stack, less_equal()); + + case ast_func_starts_with: + { + xpath_allocator_capture cr(stack.result); + + xpath_string lr = _left->eval_string(c, stack); + xpath_string rr = _right->eval_string(c, stack); + + return starts_with(lr.c_str(), rr.c_str()); + } + + case ast_func_contains: + { + xpath_allocator_capture cr(stack.result); + + xpath_string lr = _left->eval_string(c, stack); + xpath_string rr = _right->eval_string(c, stack); + + return find_substring(lr.c_str(), rr.c_str()) != 0; + } + + case ast_func_boolean: + return _left->eval_boolean(c, stack); + + case ast_func_not: + return !_left->eval_boolean(c, stack); + + case ast_func_true: + return true; + + case ast_func_false: + return false; + + case ast_func_lang: + { + if (c.n.attribute()) return false; + + xpath_allocator_capture cr(stack.result); + + xpath_string lang = _left->eval_string(c, stack); + + for (xml_node n = c.n.node(); n; n = n.parent()) + { + xml_attribute a = n.attribute(PUGIXML_TEXT("xml:lang")); + + if (a) + { + const char_t* value = a.value(); + + // strnicmp / strncasecmp is not portable + for (const char_t* lit = lang.c_str(); *lit; ++lit) + { + if (tolower_ascii(*lit) != tolower_ascii(*value)) return false; + ++value; + } + + return *value == 0 || *value == '-'; + } + } + + return false; + } + + case ast_opt_compare_attribute: + { + const char_t* value = (_right->_type == ast_string_constant) ? _right->_data.string : _right->_data.variable->get_string(); + + xml_attribute attr = c.n.node().attribute(_left->_data.nodetest); + + return attr && strequal(attr.value(), value) && is_xpath_attribute(attr.name()); + } + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_boolean) + return _data.variable->get_boolean(); + + // variable needs to be converted to the correct type, this is handled by the fallthrough block below + break; + } + + default: + ; + } + + // none of the ast types that return the value directly matched, we need to perform type conversion + switch (_rettype) + { + case xpath_type_number: + return convert_number_to_boolean(eval_number(c, stack)); + + case xpath_type_string: + { + xpath_allocator_capture cr(stack.result); + + return !eval_string(c, stack).empty(); + } + + case xpath_type_node_set: + { + xpath_allocator_capture cr(stack.result); + + return !eval_node_set(c, stack, nodeset_eval_any).empty(); + } + + default: + assert(false && "Wrong expression for return type boolean"); // unreachable + return false; + } + } + + double eval_number(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_op_add: + return _left->eval_number(c, stack) + _right->eval_number(c, stack); + + case ast_op_subtract: + return _left->eval_number(c, stack) - _right->eval_number(c, stack); + + case ast_op_multiply: + return _left->eval_number(c, stack) * _right->eval_number(c, stack); + + case ast_op_divide: + return _left->eval_number(c, stack) / _right->eval_number(c, stack); + + case ast_op_mod: + return fmod(_left->eval_number(c, stack), _right->eval_number(c, stack)); + + case ast_op_negate: + return -_left->eval_number(c, stack); + + case ast_number_constant: + return _data.number; + + case ast_func_last: + return static_cast(c.size); + + case ast_func_position: + return static_cast(c.position); + + case ast_func_count: + { + xpath_allocator_capture cr(stack.result); + + return static_cast(_left->eval_node_set(c, stack, nodeset_eval_all).size()); + } + + case ast_func_string_length_0: + { + xpath_allocator_capture cr(stack.result); + + return static_cast(string_value(c.n, stack.result).length()); + } + + case ast_func_string_length_1: + { + xpath_allocator_capture cr(stack.result); + + return static_cast(_left->eval_string(c, stack).length()); + } + + case ast_func_number_0: + { + xpath_allocator_capture cr(stack.result); + + return convert_string_to_number(string_value(c.n, stack.result).c_str()); + } + + case ast_func_number_1: + return _left->eval_number(c, stack); + + case ast_func_sum: + { + xpath_allocator_capture cr(stack.result); + + double r = 0; + + xpath_node_set_raw ns = _left->eval_node_set(c, stack, nodeset_eval_all); + + for (const xpath_node* it = ns.begin(); it != ns.end(); ++it) + { + xpath_allocator_capture cri(stack.result); + + r += convert_string_to_number(string_value(*it, stack.result).c_str()); + } + + return r; + } + + case ast_func_floor: + { + double r = _left->eval_number(c, stack); + + return r == r ? floor(r) : r; + } + + case ast_func_ceiling: + { + double r = _left->eval_number(c, stack); + + return r == r ? ceil(r) : r; + } + + case ast_func_round: + return round_nearest_nzero(_left->eval_number(c, stack)); + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_number) + return _data.variable->get_number(); + + // variable needs to be converted to the correct type, this is handled by the fallthrough block below + break; + } + + default: + ; + } + + // none of the ast types that return the value directly matched, we need to perform type conversion + switch (_rettype) + { + case xpath_type_boolean: + return eval_boolean(c, stack) ? 1 : 0; + + case xpath_type_string: + { + xpath_allocator_capture cr(stack.result); + + return convert_string_to_number(eval_string(c, stack).c_str()); + } + + case xpath_type_node_set: + { + xpath_allocator_capture cr(stack.result); + + return convert_string_to_number(eval_string(c, stack).c_str()); + } + + default: + assert(false && "Wrong expression for return type number"); // unreachable + return 0; + } + } + + xpath_string eval_string_concat(const xpath_context& c, const xpath_stack& stack) + { + assert(_type == ast_func_concat); + + xpath_allocator_capture ct(stack.temp); + + // count the string number + size_t count = 1; + for (xpath_ast_node* nc = _right; nc; nc = nc->_next) count++; + + // allocate a buffer for temporary string objects + xpath_string* buffer = static_cast(stack.temp->allocate(count * sizeof(xpath_string))); + if (!buffer) return xpath_string(); + + // evaluate all strings to temporary stack + xpath_stack swapped_stack = {stack.temp, stack.result}; + + buffer[0] = _left->eval_string(c, swapped_stack); + + size_t pos = 1; + for (xpath_ast_node* n = _right; n; n = n->_next, ++pos) buffer[pos] = n->eval_string(c, swapped_stack); + assert(pos == count); + + // get total length + size_t length = 0; + for (size_t i = 0; i < count; ++i) length += buffer[i].length(); + + // create final string + char_t* result = static_cast(stack.result->allocate((length + 1) * sizeof(char_t))); + if (!result) return xpath_string(); + + char_t* ri = result; + + for (size_t j = 0; j < count; ++j) + for (const char_t* bi = buffer[j].c_str(); *bi; ++bi) + *ri++ = *bi; + + *ri = 0; + + return xpath_string::from_heap_preallocated(result, ri); + } + + xpath_string eval_string(const xpath_context& c, const xpath_stack& stack) + { + switch (_type) + { + case ast_string_constant: + return xpath_string::from_const(_data.string); + + case ast_func_local_name_0: + { + xpath_node na = c.n; + + return xpath_string::from_const(local_name(na)); + } + + case ast_func_local_name_1: + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ns = _left->eval_node_set(c, stack, nodeset_eval_first); + xpath_node na = ns.first(); + + return xpath_string::from_const(local_name(na)); + } + + case ast_func_name_0: + { + xpath_node na = c.n; + + return xpath_string::from_const(qualified_name(na)); + } + + case ast_func_name_1: + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ns = _left->eval_node_set(c, stack, nodeset_eval_first); + xpath_node na = ns.first(); + + return xpath_string::from_const(qualified_name(na)); + } + + case ast_func_namespace_uri_0: + { + xpath_node na = c.n; + + return xpath_string::from_const(namespace_uri(na)); + } + + case ast_func_namespace_uri_1: + { + xpath_allocator_capture cr(stack.result); + + xpath_node_set_raw ns = _left->eval_node_set(c, stack, nodeset_eval_first); + xpath_node na = ns.first(); + + return xpath_string::from_const(namespace_uri(na)); + } + + case ast_func_string_0: + return string_value(c.n, stack.result); + + case ast_func_string_1: + return _left->eval_string(c, stack); + + case ast_func_concat: + return eval_string_concat(c, stack); + + case ast_func_substring_before: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + xpath_string p = _right->eval_string(c, swapped_stack); + + const char_t* pos = find_substring(s.c_str(), p.c_str()); + + return pos ? xpath_string::from_heap(s.c_str(), pos, stack.result) : xpath_string(); + } + + case ast_func_substring_after: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + xpath_string p = _right->eval_string(c, swapped_stack); + + const char_t* pos = find_substring(s.c_str(), p.c_str()); + if (!pos) return xpath_string(); + + const char_t* rbegin = pos + p.length(); + const char_t* rend = s.c_str() + s.length(); + + return s.uses_heap() ? xpath_string::from_heap(rbegin, rend, stack.result) : xpath_string::from_const(rbegin); + } + + case ast_func_substring_2: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + size_t s_length = s.length(); + + double first = round_nearest(_right->eval_number(c, stack)); + + if (is_nan(first)) return xpath_string(); // NaN + else if (first >= static_cast(s_length + 1)) return xpath_string(); + + size_t pos = first < 1 ? 1 : static_cast(first); + assert(1 <= pos && pos <= s_length + 1); + + const char_t* rbegin = s.c_str() + (pos - 1); + const char_t* rend = s.c_str() + s.length(); + + return s.uses_heap() ? xpath_string::from_heap(rbegin, rend, stack.result) : xpath_string::from_const(rbegin); + } + + case ast_func_substring_3: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, swapped_stack); + size_t s_length = s.length(); + + double first = round_nearest(_right->eval_number(c, stack)); + double last = first + round_nearest(_right->_next->eval_number(c, stack)); + + if (is_nan(first) || is_nan(last)) return xpath_string(); + else if (first >= static_cast(s_length + 1)) return xpath_string(); + else if (first >= last) return xpath_string(); + else if (last < 1) return xpath_string(); + + size_t pos = first < 1 ? 1 : static_cast(first); + size_t end = last >= static_cast(s_length + 1) ? s_length + 1 : static_cast(last); + + assert(1 <= pos && pos <= end && end <= s_length + 1); + const char_t* rbegin = s.c_str() + (pos - 1); + const char_t* rend = s.c_str() + (end - 1); + + return (end == s_length + 1 && !s.uses_heap()) ? xpath_string::from_const(rbegin) : xpath_string::from_heap(rbegin, rend, stack.result); + } + + case ast_func_normalize_space_0: + { + xpath_string s = string_value(c.n, stack.result); + + char_t* begin = s.data(stack.result); + if (!begin) return xpath_string(); + + char_t* end = normalize_space(begin); + + return xpath_string::from_heap_preallocated(begin, end); + } + + case ast_func_normalize_space_1: + { + xpath_string s = _left->eval_string(c, stack); + + char_t* begin = s.data(stack.result); + if (!begin) return xpath_string(); + + char_t* end = normalize_space(begin); + + return xpath_string::from_heap_preallocated(begin, end); + } + + case ast_func_translate: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_string s = _left->eval_string(c, stack); + xpath_string from = _right->eval_string(c, swapped_stack); + xpath_string to = _right->_next->eval_string(c, swapped_stack); + + char_t* begin = s.data(stack.result); + if (!begin) return xpath_string(); + + char_t* end = translate(begin, from.c_str(), to.c_str(), to.length()); + + return xpath_string::from_heap_preallocated(begin, end); + } + + case ast_opt_translate_table: + { + xpath_string s = _left->eval_string(c, stack); + + char_t* begin = s.data(stack.result); + if (!begin) return xpath_string(); + + char_t* end = translate_table(begin, _data.table); + + return xpath_string::from_heap_preallocated(begin, end); + } + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_string) + return xpath_string::from_const(_data.variable->get_string()); + + // variable needs to be converted to the correct type, this is handled by the fallthrough block below + break; + } + + default: + ; + } + + // none of the ast types that return the value directly matched, we need to perform type conversion + switch (_rettype) + { + case xpath_type_boolean: + return xpath_string::from_const(eval_boolean(c, stack) ? PUGIXML_TEXT("true") : PUGIXML_TEXT("false")); + + case xpath_type_number: + return convert_number_to_string(eval_number(c, stack), stack.result); + + case xpath_type_node_set: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_node_set_raw ns = eval_node_set(c, swapped_stack, nodeset_eval_first); + return ns.empty() ? xpath_string() : string_value(ns.first(), stack.result); + } + + default: + assert(false && "Wrong expression for return type string"); // unreachable + return xpath_string(); + } + } + + xpath_node_set_raw eval_node_set(const xpath_context& c, const xpath_stack& stack, nodeset_eval_t eval) + { + switch (_type) + { + case ast_op_union: + { + xpath_allocator_capture cr(stack.temp); + + xpath_stack swapped_stack = {stack.temp, stack.result}; + + xpath_node_set_raw ls = _left->eval_node_set(c, stack, eval); + xpath_node_set_raw rs = _right->eval_node_set(c, swapped_stack, eval); + + // we can optimize merging two sorted sets, but this is a very rare operation, so don't bother + ls.set_type(xpath_node_set::type_unsorted); + + ls.append(rs.begin(), rs.end(), stack.result); + ls.remove_duplicates(stack.temp); + + return ls; + } + + case ast_filter: + { + xpath_node_set_raw set = _left->eval_node_set(c, stack, _test == predicate_constant_one ? nodeset_eval_first : nodeset_eval_all); + + // either expression is a number or it contains position() call; sort by document order + if (_test != predicate_posinv) set.sort_do(); + + bool once = eval_once(set.type(), eval); + + apply_predicate(set, 0, stack, once); + + return set; + } + + case ast_func_id: + return xpath_node_set_raw(); + + case ast_step: + { + switch (_axis) + { + case axis_ancestor: + return step_do(c, stack, eval, axis_to_type()); + + case axis_ancestor_or_self: + return step_do(c, stack, eval, axis_to_type()); + + case axis_attribute: + return step_do(c, stack, eval, axis_to_type()); + + case axis_child: + return step_do(c, stack, eval, axis_to_type()); + + case axis_descendant: + return step_do(c, stack, eval, axis_to_type()); + + case axis_descendant_or_self: + return step_do(c, stack, eval, axis_to_type()); + + case axis_following: + return step_do(c, stack, eval, axis_to_type()); + + case axis_following_sibling: + return step_do(c, stack, eval, axis_to_type()); + + case axis_namespace: + // namespaced axis is not supported + return xpath_node_set_raw(); + + case axis_parent: + return step_do(c, stack, eval, axis_to_type()); + + case axis_preceding: + return step_do(c, stack, eval, axis_to_type()); + + case axis_preceding_sibling: + return step_do(c, stack, eval, axis_to_type()); + + case axis_self: + return step_do(c, stack, eval, axis_to_type()); + + default: + assert(false && "Unknown axis"); // unreachable + return xpath_node_set_raw(); + } + } + + case ast_step_root: + { + assert(!_right); // root step can't have any predicates + + xpath_node_set_raw ns; + + ns.set_type(xpath_node_set::type_sorted); + + if (c.n.node()) ns.push_back(c.n.node().root(), stack.result); + else if (c.n.attribute()) ns.push_back(c.n.parent().root(), stack.result); + + return ns; + } + + case ast_variable: + { + assert(_rettype == _data.variable->type()); + + if (_rettype == xpath_type_node_set) + { + const xpath_node_set& s = _data.variable->get_node_set(); + + xpath_node_set_raw ns; + + ns.set_type(s.type()); + ns.append(s.begin(), s.end(), stack.result); + + return ns; + } + + // variable needs to be converted to the correct type, this is handled by the fallthrough block below + break; + } + + default: + ; + } + + // none of the ast types that return the value directly matched, but conversions to node set are invalid + assert(false && "Wrong expression for return type node set"); // unreachable + return xpath_node_set_raw(); + } + + void optimize(xpath_allocator* alloc) + { + if (_left) + _left->optimize(alloc); + + if (_right) + _right->optimize(alloc); + + if (_next) + _next->optimize(alloc); + + // coverity[var_deref_model] + optimize_self(alloc); + } + + void optimize_self(xpath_allocator* alloc) + { + // Rewrite [position()=expr] with [expr] + // Note that this step has to go before classification to recognize [position()=1] + if ((_type == ast_filter || _type == ast_predicate) && + _right && // workaround for clang static analyzer (_right is never null for ast_filter/ast_predicate) + _right->_type == ast_op_equal && _right->_left->_type == ast_func_position && _right->_right->_rettype == xpath_type_number) + { + _right = _right->_right; + } + + // Classify filter/predicate ops to perform various optimizations during evaluation + if ((_type == ast_filter || _type == ast_predicate) && _right) // workaround for clang static analyzer (_right is never null for ast_filter/ast_predicate) + { + assert(_test == predicate_default); + + if (_right->_type == ast_number_constant && _right->_data.number == 1.0) + _test = predicate_constant_one; + else if (_right->_rettype == xpath_type_number && (_right->_type == ast_number_constant || _right->_type == ast_variable || _right->_type == ast_func_last)) + _test = predicate_constant; + else if (_right->_rettype != xpath_type_number && _right->is_posinv_expr()) + _test = predicate_posinv; + } + + // Rewrite descendant-or-self::node()/child::foo with descendant::foo + // The former is a full form of //foo, the latter is much faster since it executes the node test immediately + // Do a similar kind of rewrite for self/descendant/descendant-or-self axes + // Note that we only rewrite positionally invariant steps (//foo[1] != /descendant::foo[1]) + if (_type == ast_step && (_axis == axis_child || _axis == axis_self || _axis == axis_descendant || _axis == axis_descendant_or_self) && + _left && _left->_type == ast_step && _left->_axis == axis_descendant_or_self && _left->_test == nodetest_type_node && !_left->_right && + is_posinv_step()) + { + if (_axis == axis_child || _axis == axis_descendant) + _axis = axis_descendant; + else + _axis = axis_descendant_or_self; + + _left = _left->_left; + } + + // Use optimized lookup table implementation for translate() with constant arguments + if (_type == ast_func_translate && + _right && // workaround for clang static analyzer (_right is never null for ast_func_translate) + _right->_type == ast_string_constant && _right->_next->_type == ast_string_constant) + { + unsigned char* table = translate_table_generate(alloc, _right->_data.string, _right->_next->_data.string); + + if (table) + { + _type = ast_opt_translate_table; + _data.table = table; + } + } + + // Use optimized path for @attr = 'value' or @attr = $value + if (_type == ast_op_equal && + _left && _right && // workaround for clang static analyzer and Coverity (_left and _right are never null for ast_op_equal) + // coverity[mixed_enums] + _left->_type == ast_step && _left->_axis == axis_attribute && _left->_test == nodetest_name && !_left->_left && !_left->_right && + (_right->_type == ast_string_constant || (_right->_type == ast_variable && _right->_rettype == xpath_type_string))) + { + _type = ast_opt_compare_attribute; + } + } + + bool is_posinv_expr() const + { + switch (_type) + { + case ast_func_position: + case ast_func_last: + return false; + + case ast_string_constant: + case ast_number_constant: + case ast_variable: + return true; + + case ast_step: + case ast_step_root: + return true; + + case ast_predicate: + case ast_filter: + return true; + + default: + if (_left && !_left->is_posinv_expr()) return false; + + for (xpath_ast_node* n = _right; n; n = n->_next) + if (!n->is_posinv_expr()) return false; + + return true; + } + } + + bool is_posinv_step() const + { + assert(_type == ast_step); + + for (xpath_ast_node* n = _right; n; n = n->_next) + { + assert(n->_type == ast_predicate); + + if (n->_test != predicate_posinv) + return false; + } + + return true; + } + + xpath_value_type rettype() const + { + return static_cast(_rettype); + } + }; + + static const size_t xpath_ast_depth_limit = + #ifdef PUGIXML_XPATH_DEPTH_LIMIT + PUGIXML_XPATH_DEPTH_LIMIT + #else + 1024 + #endif + ; + + struct xpath_parser + { + xpath_allocator* _alloc; + xpath_lexer _lexer; + + const char_t* _query; + xpath_variable_set* _variables; + + xpath_parse_result* _result; + + char_t _scratch[32]; + + size_t _depth; + + xpath_ast_node* error(const char* message) + { + _result->error = message; + _result->offset = _lexer.current_pos() - _query; + + return 0; + } + + xpath_ast_node* error_oom() + { + assert(_alloc->_error); + *_alloc->_error = true; + + return 0; + } + + xpath_ast_node* error_rec() + { + return error("Exceeded maximum allowed query depth"); + } + + void* alloc_node() + { + return _alloc->allocate(sizeof(xpath_ast_node)); + } + + xpath_ast_node* alloc_node(ast_type_t type, xpath_value_type rettype, const char_t* value) + { + void* memory = alloc_node(); + return memory ? new (memory) xpath_ast_node(type, rettype, value) : 0; + } + + xpath_ast_node* alloc_node(ast_type_t type, xpath_value_type rettype, double value) + { + void* memory = alloc_node(); + return memory ? new (memory) xpath_ast_node(type, rettype, value) : 0; + } + + xpath_ast_node* alloc_node(ast_type_t type, xpath_value_type rettype, xpath_variable* value) + { + void* memory = alloc_node(); + return memory ? new (memory) xpath_ast_node(type, rettype, value) : 0; + } + + xpath_ast_node* alloc_node(ast_type_t type, xpath_value_type rettype, xpath_ast_node* left = 0, xpath_ast_node* right = 0) + { + void* memory = alloc_node(); + return memory ? new (memory) xpath_ast_node(type, rettype, left, right) : 0; + } + + xpath_ast_node* alloc_node(ast_type_t type, xpath_ast_node* left, axis_t axis, nodetest_t test, const char_t* contents) + { + void* memory = alloc_node(); + return memory ? new (memory) xpath_ast_node(type, left, axis, test, contents) : 0; + } + + xpath_ast_node* alloc_node(ast_type_t type, xpath_ast_node* left, xpath_ast_node* right, predicate_t test) + { + void* memory = alloc_node(); + return memory ? new (memory) xpath_ast_node(type, left, right, test) : 0; + } + + const char_t* alloc_string(const xpath_lexer_string& value) + { + if (!value.begin) + return PUGIXML_TEXT(""); + + size_t length = static_cast(value.end - value.begin); + + char_t* c = static_cast(_alloc->allocate((length + 1) * sizeof(char_t))); + if (!c) return 0; + + memcpy(c, value.begin, length * sizeof(char_t)); + c[length] = 0; + + return c; + } + + xpath_ast_node* parse_function(const xpath_lexer_string& name, size_t argc, xpath_ast_node* args[2]) + { + switch (name.begin[0]) + { + case 'b': + if (name == PUGIXML_TEXT("boolean") && argc == 1) + return alloc_node(ast_func_boolean, xpath_type_boolean, args[0]); + + break; + + case 'c': + if (name == PUGIXML_TEXT("count") && argc == 1) + { + if (args[0]->rettype() != xpath_type_node_set) return error("Function has to be applied to node set"); + return alloc_node(ast_func_count, xpath_type_number, args[0]); + } + else if (name == PUGIXML_TEXT("contains") && argc == 2) + return alloc_node(ast_func_contains, xpath_type_boolean, args[0], args[1]); + else if (name == PUGIXML_TEXT("concat") && argc >= 2) + return alloc_node(ast_func_concat, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("ceiling") && argc == 1) + return alloc_node(ast_func_ceiling, xpath_type_number, args[0]); + + break; + + case 'f': + if (name == PUGIXML_TEXT("false") && argc == 0) + return alloc_node(ast_func_false, xpath_type_boolean); + else if (name == PUGIXML_TEXT("floor") && argc == 1) + return alloc_node(ast_func_floor, xpath_type_number, args[0]); + + break; + + case 'i': + if (name == PUGIXML_TEXT("id") && argc == 1) + return alloc_node(ast_func_id, xpath_type_node_set, args[0]); + + break; + + case 'l': + if (name == PUGIXML_TEXT("last") && argc == 0) + return alloc_node(ast_func_last, xpath_type_number); + else if (name == PUGIXML_TEXT("lang") && argc == 1) + return alloc_node(ast_func_lang, xpath_type_boolean, args[0]); + else if (name == PUGIXML_TEXT("local-name") && argc <= 1) + { + if (argc == 1 && args[0]->rettype() != xpath_type_node_set) return error("Function has to be applied to node set"); + return alloc_node(argc == 0 ? ast_func_local_name_0 : ast_func_local_name_1, xpath_type_string, args[0]); + } + + break; + + case 'n': + if (name == PUGIXML_TEXT("name") && argc <= 1) + { + if (argc == 1 && args[0]->rettype() != xpath_type_node_set) return error("Function has to be applied to node set"); + return alloc_node(argc == 0 ? ast_func_name_0 : ast_func_name_1, xpath_type_string, args[0]); + } + else if (name == PUGIXML_TEXT("namespace-uri") && argc <= 1) + { + if (argc == 1 && args[0]->rettype() != xpath_type_node_set) return error("Function has to be applied to node set"); + return alloc_node(argc == 0 ? ast_func_namespace_uri_0 : ast_func_namespace_uri_1, xpath_type_string, args[0]); + } + else if (name == PUGIXML_TEXT("normalize-space") && argc <= 1) + return alloc_node(argc == 0 ? ast_func_normalize_space_0 : ast_func_normalize_space_1, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("not") && argc == 1) + return alloc_node(ast_func_not, xpath_type_boolean, args[0]); + else if (name == PUGIXML_TEXT("number") && argc <= 1) + return alloc_node(argc == 0 ? ast_func_number_0 : ast_func_number_1, xpath_type_number, args[0]); + + break; + + case 'p': + if (name == PUGIXML_TEXT("position") && argc == 0) + return alloc_node(ast_func_position, xpath_type_number); + + break; + + case 'r': + if (name == PUGIXML_TEXT("round") && argc == 1) + return alloc_node(ast_func_round, xpath_type_number, args[0]); + + break; + + case 's': + if (name == PUGIXML_TEXT("string") && argc <= 1) + return alloc_node(argc == 0 ? ast_func_string_0 : ast_func_string_1, xpath_type_string, args[0]); + else if (name == PUGIXML_TEXT("string-length") && argc <= 1) + return alloc_node(argc == 0 ? ast_func_string_length_0 : ast_func_string_length_1, xpath_type_number, args[0]); + else if (name == PUGIXML_TEXT("starts-with") && argc == 2) + return alloc_node(ast_func_starts_with, xpath_type_boolean, args[0], args[1]); + else if (name == PUGIXML_TEXT("substring-before") && argc == 2) + return alloc_node(ast_func_substring_before, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("substring-after") && argc == 2) + return alloc_node(ast_func_substring_after, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("substring") && (argc == 2 || argc == 3)) + return alloc_node(argc == 2 ? ast_func_substring_2 : ast_func_substring_3, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("sum") && argc == 1) + { + if (args[0]->rettype() != xpath_type_node_set) return error("Function has to be applied to node set"); + return alloc_node(ast_func_sum, xpath_type_number, args[0]); + } + + break; + + case 't': + if (name == PUGIXML_TEXT("translate") && argc == 3) + return alloc_node(ast_func_translate, xpath_type_string, args[0], args[1]); + else if (name == PUGIXML_TEXT("true") && argc == 0) + return alloc_node(ast_func_true, xpath_type_boolean); + + break; + + default: + break; + } + + return error("Unrecognized function or wrong parameter count"); + } + + axis_t parse_axis_name(const xpath_lexer_string& name, bool& specified) + { + specified = true; + + switch (name.begin[0]) + { + case 'a': + if (name == PUGIXML_TEXT("ancestor")) + return axis_ancestor; + else if (name == PUGIXML_TEXT("ancestor-or-self")) + return axis_ancestor_or_self; + else if (name == PUGIXML_TEXT("attribute")) + return axis_attribute; + + break; + + case 'c': + if (name == PUGIXML_TEXT("child")) + return axis_child; + + break; + + case 'd': + if (name == PUGIXML_TEXT("descendant")) + return axis_descendant; + else if (name == PUGIXML_TEXT("descendant-or-self")) + return axis_descendant_or_self; + + break; + + case 'f': + if (name == PUGIXML_TEXT("following")) + return axis_following; + else if (name == PUGIXML_TEXT("following-sibling")) + return axis_following_sibling; + + break; + + case 'n': + if (name == PUGIXML_TEXT("namespace")) + return axis_namespace; + + break; + + case 'p': + if (name == PUGIXML_TEXT("parent")) + return axis_parent; + else if (name == PUGIXML_TEXT("preceding")) + return axis_preceding; + else if (name == PUGIXML_TEXT("preceding-sibling")) + return axis_preceding_sibling; + + break; + + case 's': + if (name == PUGIXML_TEXT("self")) + return axis_self; + + break; + + default: + break; + } + + specified = false; + return axis_child; + } + + nodetest_t parse_node_test_type(const xpath_lexer_string& name) + { + switch (name.begin[0]) + { + case 'c': + if (name == PUGIXML_TEXT("comment")) + return nodetest_type_comment; + + break; + + case 'n': + if (name == PUGIXML_TEXT("node")) + return nodetest_type_node; + + break; + + case 'p': + if (name == PUGIXML_TEXT("processing-instruction")) + return nodetest_type_pi; + + break; + + case 't': + if (name == PUGIXML_TEXT("text")) + return nodetest_type_text; + + break; + + default: + break; + } + + return nodetest_none; + } + + // PrimaryExpr ::= VariableReference | '(' Expr ')' | Literal | Number | FunctionCall + xpath_ast_node* parse_primary_expression() + { + switch (_lexer.current()) + { + case lex_var_ref: + { + xpath_lexer_string name = _lexer.contents(); + + if (!_variables) + return error("Unknown variable: variable set is not provided"); + + xpath_variable* var = 0; + if (!get_variable_scratch(_scratch, _variables, name.begin, name.end, &var)) + return error_oom(); + + if (!var) + return error("Unknown variable: variable set does not contain the given name"); + + _lexer.next(); + + return alloc_node(ast_variable, var->type(), var); + } + + case lex_open_brace: + { + _lexer.next(); + + xpath_ast_node* n = parse_expression(); + if (!n) return 0; + + if (_lexer.current() != lex_close_brace) + return error("Expected ')' to match an opening '('"); + + _lexer.next(); + + return n; + } + + case lex_quoted_string: + { + const char_t* value = alloc_string(_lexer.contents()); + if (!value) return 0; + + _lexer.next(); + + return alloc_node(ast_string_constant, xpath_type_string, value); + } + + case lex_number: + { + double value = 0; + + if (!convert_string_to_number_scratch(_scratch, _lexer.contents().begin, _lexer.contents().end, &value)) + return error_oom(); + + _lexer.next(); + + return alloc_node(ast_number_constant, xpath_type_number, value); + } + + case lex_string: + { + xpath_ast_node* args[2] = {0}; + size_t argc = 0; + + xpath_lexer_string function = _lexer.contents(); + _lexer.next(); + + xpath_ast_node* last_arg = 0; + + if (_lexer.current() != lex_open_brace) + return error("Unrecognized function call"); + _lexer.next(); + + size_t old_depth = _depth; + + while (_lexer.current() != lex_close_brace) + { + if (argc > 0) + { + if (_lexer.current() != lex_comma) + return error("No comma between function arguments"); + _lexer.next(); + } + + if (++_depth > xpath_ast_depth_limit) + return error_rec(); + + xpath_ast_node* n = parse_expression(); + if (!n) return 0; + + if (argc < 2) args[argc] = n; + else last_arg->set_next(n); + + argc++; + last_arg = n; + } + + _lexer.next(); + + _depth = old_depth; + + return parse_function(function, argc, args); + } + + default: + return error("Unrecognizable primary expression"); + } + } + + // FilterExpr ::= PrimaryExpr | FilterExpr Predicate + // Predicate ::= '[' PredicateExpr ']' + // PredicateExpr ::= Expr + xpath_ast_node* parse_filter_expression() + { + xpath_ast_node* n = parse_primary_expression(); + if (!n) return 0; + + size_t old_depth = _depth; + + while (_lexer.current() == lex_open_square_brace) + { + _lexer.next(); + + if (++_depth > xpath_ast_depth_limit) + return error_rec(); + + if (n->rettype() != xpath_type_node_set) + return error("Predicate has to be applied to node set"); + + xpath_ast_node* expr = parse_expression(); + if (!expr) return 0; + + n = alloc_node(ast_filter, n, expr, predicate_default); + if (!n) return 0; + + if (_lexer.current() != lex_close_square_brace) + return error("Expected ']' to match an opening '['"); + + _lexer.next(); + } + + _depth = old_depth; + + return n; + } + + // Step ::= AxisSpecifier NodeTest Predicate* | AbbreviatedStep + // AxisSpecifier ::= AxisName '::' | '@'? + // NodeTest ::= NameTest | NodeType '(' ')' | 'processing-instruction' '(' Literal ')' + // NameTest ::= '*' | NCName ':' '*' | QName + // AbbreviatedStep ::= '.' | '..' + xpath_ast_node* parse_step(xpath_ast_node* set) + { + if (set && set->rettype() != xpath_type_node_set) + return error("Step has to be applied to node set"); + + bool axis_specified = false; + axis_t axis = axis_child; // implied child axis + + if (_lexer.current() == lex_axis_attribute) + { + axis = axis_attribute; + axis_specified = true; + + _lexer.next(); + } + else if (_lexer.current() == lex_dot) + { + _lexer.next(); + + if (_lexer.current() == lex_open_square_brace) + return error("Predicates are not allowed after an abbreviated step"); + + return alloc_node(ast_step, set, axis_self, nodetest_type_node, 0); + } + else if (_lexer.current() == lex_double_dot) + { + _lexer.next(); + + if (_lexer.current() == lex_open_square_brace) + return error("Predicates are not allowed after an abbreviated step"); + + return alloc_node(ast_step, set, axis_parent, nodetest_type_node, 0); + } + + nodetest_t nt_type = nodetest_none; + xpath_lexer_string nt_name; + + if (_lexer.current() == lex_string) + { + // node name test + nt_name = _lexer.contents(); + _lexer.next(); + + // was it an axis name? + if (_lexer.current() == lex_double_colon) + { + // parse axis name + if (axis_specified) + return error("Two axis specifiers in one step"); + + axis = parse_axis_name(nt_name, axis_specified); + + if (!axis_specified) + return error("Unknown axis"); + + // read actual node test + _lexer.next(); + + if (_lexer.current() == lex_multiply) + { + nt_type = nodetest_all; + nt_name = xpath_lexer_string(); + _lexer.next(); + } + else if (_lexer.current() == lex_string) + { + nt_name = _lexer.contents(); + _lexer.next(); + } + else + { + return error("Unrecognized node test"); + } + } + + if (nt_type == nodetest_none) + { + // node type test or processing-instruction + if (_lexer.current() == lex_open_brace) + { + _lexer.next(); + + if (_lexer.current() == lex_close_brace) + { + _lexer.next(); + + nt_type = parse_node_test_type(nt_name); + + if (nt_type == nodetest_none) + return error("Unrecognized node type"); + + nt_name = xpath_lexer_string(); + } + else if (nt_name == PUGIXML_TEXT("processing-instruction")) + { + if (_lexer.current() != lex_quoted_string) + return error("Only literals are allowed as arguments to processing-instruction()"); + + nt_type = nodetest_pi; + nt_name = _lexer.contents(); + _lexer.next(); + + if (_lexer.current() != lex_close_brace) + return error("Unmatched brace near processing-instruction()"); + _lexer.next(); + } + else + { + return error("Unmatched brace near node type test"); + } + } + // QName or NCName:* + else + { + if (nt_name.end - nt_name.begin > 2 && nt_name.end[-2] == ':' && nt_name.end[-1] == '*') // NCName:* + { + nt_name.end--; // erase * + + nt_type = nodetest_all_in_namespace; + } + else + { + nt_type = nodetest_name; + } + } + } + } + else if (_lexer.current() == lex_multiply) + { + nt_type = nodetest_all; + _lexer.next(); + } + else + { + return error("Unrecognized node test"); + } + + const char_t* nt_name_copy = alloc_string(nt_name); + if (!nt_name_copy) return 0; + + xpath_ast_node* n = alloc_node(ast_step, set, axis, nt_type, nt_name_copy); + if (!n) return 0; + + size_t old_depth = _depth; + + xpath_ast_node* last = 0; + + while (_lexer.current() == lex_open_square_brace) + { + _lexer.next(); + + if (++_depth > xpath_ast_depth_limit) + return error_rec(); + + xpath_ast_node* expr = parse_expression(); + if (!expr) return 0; + + xpath_ast_node* pred = alloc_node(ast_predicate, 0, expr, predicate_default); + if (!pred) return 0; + + if (_lexer.current() != lex_close_square_brace) + return error("Expected ']' to match an opening '['"); + _lexer.next(); + + if (last) last->set_next(pred); + else n->set_right(pred); + + last = pred; + } + + _depth = old_depth; + + return n; + } + + // RelativeLocationPath ::= Step | RelativeLocationPath '/' Step | RelativeLocationPath '//' Step + xpath_ast_node* parse_relative_location_path(xpath_ast_node* set) + { + xpath_ast_node* n = parse_step(set); + if (!n) return 0; + + size_t old_depth = _depth; + + while (_lexer.current() == lex_slash || _lexer.current() == lex_double_slash) + { + lexeme_t l = _lexer.current(); + _lexer.next(); + + if (l == lex_double_slash) + { + n = alloc_node(ast_step, n, axis_descendant_or_self, nodetest_type_node, 0); + if (!n) return 0; + + ++_depth; + } + + if (++_depth > xpath_ast_depth_limit) + return error_rec(); + + n = parse_step(n); + if (!n) return 0; + } + + _depth = old_depth; + + return n; + } + + // LocationPath ::= RelativeLocationPath | AbsoluteLocationPath + // AbsoluteLocationPath ::= '/' RelativeLocationPath? | '//' RelativeLocationPath + xpath_ast_node* parse_location_path() + { + if (_lexer.current() == lex_slash) + { + _lexer.next(); + + xpath_ast_node* n = alloc_node(ast_step_root, xpath_type_node_set); + if (!n) return 0; + + // relative location path can start from axis_attribute, dot, double_dot, multiply and string lexemes; any other lexeme means standalone root path + lexeme_t l = _lexer.current(); + + if (l == lex_string || l == lex_axis_attribute || l == lex_dot || l == lex_double_dot || l == lex_multiply) + return parse_relative_location_path(n); + else + return n; + } + else if (_lexer.current() == lex_double_slash) + { + _lexer.next(); + + xpath_ast_node* n = alloc_node(ast_step_root, xpath_type_node_set); + if (!n) return 0; + + n = alloc_node(ast_step, n, axis_descendant_or_self, nodetest_type_node, 0); + if (!n) return 0; + + return parse_relative_location_path(n); + } + + // else clause moved outside of if because of bogus warning 'control may reach end of non-void function being inlined' in gcc 4.0.1 + return parse_relative_location_path(0); + } + + // PathExpr ::= LocationPath + // | FilterExpr + // | FilterExpr '/' RelativeLocationPath + // | FilterExpr '//' RelativeLocationPath + // UnionExpr ::= PathExpr | UnionExpr '|' PathExpr + // UnaryExpr ::= UnionExpr | '-' UnaryExpr + xpath_ast_node* parse_path_or_unary_expression() + { + // Clarification. + // PathExpr begins with either LocationPath or FilterExpr. + // FilterExpr begins with PrimaryExpr + // PrimaryExpr begins with '$' in case of it being a variable reference, + // '(' in case of it being an expression, string literal, number constant or + // function call. + if (_lexer.current() == lex_var_ref || _lexer.current() == lex_open_brace || + _lexer.current() == lex_quoted_string || _lexer.current() == lex_number || + _lexer.current() == lex_string) + { + if (_lexer.current() == lex_string) + { + // This is either a function call, or not - if not, we shall proceed with location path + const char_t* state = _lexer.state(); + + while (PUGI__IS_CHARTYPE(*state, ct_space)) ++state; + + if (*state != '(') + return parse_location_path(); + + // This looks like a function call; however this still can be a node-test. Check it. + if (parse_node_test_type(_lexer.contents()) != nodetest_none) + return parse_location_path(); + } + + xpath_ast_node* n = parse_filter_expression(); + if (!n) return 0; + + if (_lexer.current() == lex_slash || _lexer.current() == lex_double_slash) + { + lexeme_t l = _lexer.current(); + _lexer.next(); + + if (l == lex_double_slash) + { + if (n->rettype() != xpath_type_node_set) + return error("Step has to be applied to node set"); + + n = alloc_node(ast_step, n, axis_descendant_or_self, nodetest_type_node, 0); + if (!n) return 0; + } + + // select from location path + return parse_relative_location_path(n); + } + + return n; + } + else if (_lexer.current() == lex_minus) + { + _lexer.next(); + + // precedence 7+ - only parses union expressions + xpath_ast_node* n = parse_expression(7); + if (!n) return 0; + + return alloc_node(ast_op_negate, xpath_type_number, n); + } + else + { + return parse_location_path(); + } + } + + struct binary_op_t + { + ast_type_t asttype; + xpath_value_type rettype; + int precedence; + + binary_op_t(): asttype(ast_unknown), rettype(xpath_type_none), precedence(0) + { + } + + binary_op_t(ast_type_t asttype_, xpath_value_type rettype_, int precedence_): asttype(asttype_), rettype(rettype_), precedence(precedence_) + { + } + + static binary_op_t parse(xpath_lexer& lexer) + { + switch (lexer.current()) + { + case lex_string: + if (lexer.contents() == PUGIXML_TEXT("or")) + return binary_op_t(ast_op_or, xpath_type_boolean, 1); + else if (lexer.contents() == PUGIXML_TEXT("and")) + return binary_op_t(ast_op_and, xpath_type_boolean, 2); + else if (lexer.contents() == PUGIXML_TEXT("div")) + return binary_op_t(ast_op_divide, xpath_type_number, 6); + else if (lexer.contents() == PUGIXML_TEXT("mod")) + return binary_op_t(ast_op_mod, xpath_type_number, 6); + else + return binary_op_t(); + + case lex_equal: + return binary_op_t(ast_op_equal, xpath_type_boolean, 3); + + case lex_not_equal: + return binary_op_t(ast_op_not_equal, xpath_type_boolean, 3); + + case lex_less: + return binary_op_t(ast_op_less, xpath_type_boolean, 4); + + case lex_greater: + return binary_op_t(ast_op_greater, xpath_type_boolean, 4); + + case lex_less_or_equal: + return binary_op_t(ast_op_less_or_equal, xpath_type_boolean, 4); + + case lex_greater_or_equal: + return binary_op_t(ast_op_greater_or_equal, xpath_type_boolean, 4); + + case lex_plus: + return binary_op_t(ast_op_add, xpath_type_number, 5); + + case lex_minus: + return binary_op_t(ast_op_subtract, xpath_type_number, 5); + + case lex_multiply: + return binary_op_t(ast_op_multiply, xpath_type_number, 6); + + case lex_union: + return binary_op_t(ast_op_union, xpath_type_node_set, 7); + + default: + return binary_op_t(); + } + } + }; + + xpath_ast_node* parse_expression_rec(xpath_ast_node* lhs, int limit) + { + binary_op_t op = binary_op_t::parse(_lexer); + + while (op.asttype != ast_unknown && op.precedence >= limit) + { + _lexer.next(); + + if (++_depth > xpath_ast_depth_limit) + return error_rec(); + + xpath_ast_node* rhs = parse_path_or_unary_expression(); + if (!rhs) return 0; + + binary_op_t nextop = binary_op_t::parse(_lexer); + + while (nextop.asttype != ast_unknown && nextop.precedence > op.precedence) + { + rhs = parse_expression_rec(rhs, nextop.precedence); + if (!rhs) return 0; + + nextop = binary_op_t::parse(_lexer); + } + + if (op.asttype == ast_op_union && (lhs->rettype() != xpath_type_node_set || rhs->rettype() != xpath_type_node_set)) + return error("Union operator has to be applied to node sets"); + + lhs = alloc_node(op.asttype, op.rettype, lhs, rhs); + if (!lhs) return 0; + + op = binary_op_t::parse(_lexer); + } + + return lhs; + } + + // Expr ::= OrExpr + // OrExpr ::= AndExpr | OrExpr 'or' AndExpr + // AndExpr ::= EqualityExpr | AndExpr 'and' EqualityExpr + // EqualityExpr ::= RelationalExpr + // | EqualityExpr '=' RelationalExpr + // | EqualityExpr '!=' RelationalExpr + // RelationalExpr ::= AdditiveExpr + // | RelationalExpr '<' AdditiveExpr + // | RelationalExpr '>' AdditiveExpr + // | RelationalExpr '<=' AdditiveExpr + // | RelationalExpr '>=' AdditiveExpr + // AdditiveExpr ::= MultiplicativeExpr + // | AdditiveExpr '+' MultiplicativeExpr + // | AdditiveExpr '-' MultiplicativeExpr + // MultiplicativeExpr ::= UnaryExpr + // | MultiplicativeExpr '*' UnaryExpr + // | MultiplicativeExpr 'div' UnaryExpr + // | MultiplicativeExpr 'mod' UnaryExpr + xpath_ast_node* parse_expression(int limit = 0) + { + size_t old_depth = _depth; + + if (++_depth > xpath_ast_depth_limit) + return error_rec(); + + xpath_ast_node* n = parse_path_or_unary_expression(); + if (!n) return 0; + + n = parse_expression_rec(n, limit); + + _depth = old_depth; + + return n; + } + + xpath_parser(const char_t* query, xpath_variable_set* variables, xpath_allocator* alloc, xpath_parse_result* result): _alloc(alloc), _lexer(query), _query(query), _variables(variables), _result(result), _depth(0) + { + } + + xpath_ast_node* parse() + { + xpath_ast_node* n = parse_expression(); + if (!n) return 0; + + assert(_depth == 0); + + // check if there are unparsed tokens left + if (_lexer.current() != lex_eof) + return error("Incorrect query"); + + return n; + } + + static xpath_ast_node* parse(const char_t* query, xpath_variable_set* variables, xpath_allocator* alloc, xpath_parse_result* result) + { + xpath_parser parser(query, variables, alloc, result); + + return parser.parse(); + } + }; + + struct xpath_query_impl + { + static xpath_query_impl* create() + { + void* memory = xml_memory::allocate(sizeof(xpath_query_impl)); + if (!memory) return 0; + + return new (memory) xpath_query_impl(); + } + + static void destroy(xpath_query_impl* impl) + { + // free all allocated pages + impl->alloc.release(); + + // free allocator memory (with the first page) + xml_memory::deallocate(impl); + } + + xpath_query_impl(): root(0), alloc(&block, &oom), oom(false) + { + block.next = 0; + block.capacity = sizeof(block.data); + } + + xpath_ast_node* root; + xpath_allocator alloc; + xpath_memory_block block; + bool oom; + }; + + PUGI__FN impl::xpath_ast_node* evaluate_node_set_prepare(xpath_query_impl* impl) + { + if (!impl) return 0; + + if (impl->root->rettype() != xpath_type_node_set) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return 0; + #else + xpath_parse_result res; + res.error = "Expression does not evaluate to node set"; + + throw xpath_exception(res); + #endif + } + + return impl->root; + } +PUGI__NS_END + +namespace pugi +{ +#ifndef PUGIXML_NO_EXCEPTIONS + PUGI__FN xpath_exception::xpath_exception(const xpath_parse_result& result_): _result(result_) + { + assert(_result.error); + } + + PUGI__FN const char* xpath_exception::what() const throw() + { + return _result.error; + } + + PUGI__FN const xpath_parse_result& xpath_exception::result() const + { + return _result; + } +#endif + + PUGI__FN xpath_node::xpath_node() + { + } + + PUGI__FN xpath_node::xpath_node(const xml_node& node_): _node(node_) + { + } + + PUGI__FN xpath_node::xpath_node(const xml_attribute& attribute_, const xml_node& parent_): _node(attribute_ ? parent_ : xml_node()), _attribute(attribute_) + { + } + + PUGI__FN xml_node xpath_node::node() const + { + return _attribute ? xml_node() : _node; + } + + PUGI__FN xml_attribute xpath_node::attribute() const + { + return _attribute; + } + + PUGI__FN xml_node xpath_node::parent() const + { + return _attribute ? _node : _node.parent(); + } + + PUGI__FN static void unspecified_bool_xpath_node(xpath_node***) + { + } + + PUGI__FN xpath_node::operator xpath_node::unspecified_bool_type() const + { + return (_node || _attribute) ? unspecified_bool_xpath_node : 0; + } + + PUGI__FN bool xpath_node::operator!() const + { + return !(_node || _attribute); + } + + PUGI__FN bool xpath_node::operator==(const xpath_node& n) const + { + return _node == n._node && _attribute == n._attribute; + } + + PUGI__FN bool xpath_node::operator!=(const xpath_node& n) const + { + return _node != n._node || _attribute != n._attribute; + } + +#ifdef __BORLANDC__ + PUGI__FN bool operator&&(const xpath_node& lhs, bool rhs) + { + return (bool)lhs && rhs; + } + + PUGI__FN bool operator||(const xpath_node& lhs, bool rhs) + { + return (bool)lhs || rhs; + } +#endif + + PUGI__FN void xpath_node_set::_assign(const_iterator begin_, const_iterator end_, type_t type_) + { + assert(begin_ <= end_); + + size_t size_ = static_cast(end_ - begin_); + + // use internal buffer for 0 or 1 elements, heap buffer otherwise + xpath_node* storage = (size_ <= 1) ? _storage : static_cast(impl::xml_memory::allocate(size_ * sizeof(xpath_node))); + + if (!storage) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return; + #else + throw std::bad_alloc(); + #endif + } + + // deallocate old buffer + if (_begin != _storage) + impl::xml_memory::deallocate(_begin); + + // size check is necessary because for begin_ = end_ = nullptr, memcpy is UB + if (size_) + memcpy(storage, begin_, size_ * sizeof(xpath_node)); + + _begin = storage; + _end = storage + size_; + _type = type_; + } + +#ifdef PUGIXML_HAS_MOVE + PUGI__FN void xpath_node_set::_move(xpath_node_set& rhs) PUGIXML_NOEXCEPT + { + _type = rhs._type; + _storage[0] = rhs._storage[0]; + _begin = (rhs._begin == rhs._storage) ? _storage : rhs._begin; + _end = _begin + (rhs._end - rhs._begin); + + rhs._type = type_unsorted; + rhs._begin = rhs._storage; + rhs._end = rhs._storage; + } +#endif + + PUGI__FN xpath_node_set::xpath_node_set(): _type(type_unsorted), _begin(_storage), _end(_storage) + { + } + + PUGI__FN xpath_node_set::xpath_node_set(const_iterator begin_, const_iterator end_, type_t type_): _type(type_unsorted), _begin(_storage), _end(_storage) + { + _assign(begin_, end_, type_); + } + + PUGI__FN xpath_node_set::~xpath_node_set() + { + if (_begin != _storage) + impl::xml_memory::deallocate(_begin); + } + + PUGI__FN xpath_node_set::xpath_node_set(const xpath_node_set& ns): _type(type_unsorted), _begin(_storage), _end(_storage) + { + _assign(ns._begin, ns._end, ns._type); + } + + PUGI__FN xpath_node_set& xpath_node_set::operator=(const xpath_node_set& ns) + { + if (this == &ns) return *this; + + _assign(ns._begin, ns._end, ns._type); + + return *this; + } + +#ifdef PUGIXML_HAS_MOVE + PUGI__FN xpath_node_set::xpath_node_set(xpath_node_set&& rhs) PUGIXML_NOEXCEPT: _type(type_unsorted), _begin(_storage), _end(_storage) + { + _move(rhs); + } + + PUGI__FN xpath_node_set& xpath_node_set::operator=(xpath_node_set&& rhs) PUGIXML_NOEXCEPT + { + if (this == &rhs) return *this; + + if (_begin != _storage) + impl::xml_memory::deallocate(_begin); + + _move(rhs); + + return *this; + } +#endif + + PUGI__FN xpath_node_set::type_t xpath_node_set::type() const + { + return _type; + } + + PUGI__FN size_t xpath_node_set::size() const + { + return _end - _begin; + } + + PUGI__FN bool xpath_node_set::empty() const + { + return _begin == _end; + } + + PUGI__FN const xpath_node& xpath_node_set::operator[](size_t index) const + { + assert(index < size()); + return _begin[index]; + } + + PUGI__FN xpath_node_set::const_iterator xpath_node_set::begin() const + { + return _begin; + } + + PUGI__FN xpath_node_set::const_iterator xpath_node_set::end() const + { + return _end; + } + + PUGI__FN void xpath_node_set::sort(bool reverse) + { + _type = impl::xpath_sort(_begin, _end, _type, reverse); + } + + PUGI__FN xpath_node xpath_node_set::first() const + { + return impl::xpath_first(_begin, _end, _type); + } + + PUGI__FN xpath_parse_result::xpath_parse_result(): error("Internal error"), offset(0) + { + } + + PUGI__FN xpath_parse_result::operator bool() const + { + return error == 0; + } + + PUGI__FN const char* xpath_parse_result::description() const + { + return error ? error : "No error"; + } + + PUGI__FN xpath_variable::xpath_variable(xpath_value_type type_): _type(type_), _next(0) + { + } + + PUGI__FN const char_t* xpath_variable::name() const + { + switch (_type) + { + case xpath_type_node_set: + return static_cast(this)->name; + + case xpath_type_number: + return static_cast(this)->name; + + case xpath_type_string: + return static_cast(this)->name; + + case xpath_type_boolean: + return static_cast(this)->name; + + default: + assert(false && "Invalid variable type"); // unreachable + return 0; + } + } + + PUGI__FN xpath_value_type xpath_variable::type() const + { + return _type; + } + + PUGI__FN bool xpath_variable::get_boolean() const + { + return (_type == xpath_type_boolean) ? static_cast(this)->value : false; + } + + PUGI__FN double xpath_variable::get_number() const + { + return (_type == xpath_type_number) ? static_cast(this)->value : impl::gen_nan(); + } + + PUGI__FN const char_t* xpath_variable::get_string() const + { + const char_t* value = (_type == xpath_type_string) ? static_cast(this)->value : 0; + return value ? value : PUGIXML_TEXT(""); + } + + PUGI__FN const xpath_node_set& xpath_variable::get_node_set() const + { + return (_type == xpath_type_node_set) ? static_cast(this)->value : impl::dummy_node_set; + } + + PUGI__FN bool xpath_variable::set(bool value) + { + if (_type != xpath_type_boolean) return false; + + static_cast(this)->value = value; + return true; + } + + PUGI__FN bool xpath_variable::set(double value) + { + if (_type != xpath_type_number) return false; + + static_cast(this)->value = value; + return true; + } + + PUGI__FN bool xpath_variable::set(const char_t* value) + { + if (_type != xpath_type_string) return false; + + impl::xpath_variable_string* var = static_cast(this); + + // duplicate string + size_t size = (impl::strlength(value) + 1) * sizeof(char_t); + + char_t* copy = static_cast(impl::xml_memory::allocate(size)); + if (!copy) return false; + + memcpy(copy, value, size); + + // replace old string + if (var->value) impl::xml_memory::deallocate(var->value); + var->value = copy; + + return true; + } + + PUGI__FN bool xpath_variable::set(const xpath_node_set& value) + { + if (_type != xpath_type_node_set) return false; + + static_cast(this)->value = value; + return true; + } + + PUGI__FN xpath_variable_set::xpath_variable_set() + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + _data[i] = 0; + } + + PUGI__FN xpath_variable_set::~xpath_variable_set() + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + _destroy(_data[i]); + } + + PUGI__FN xpath_variable_set::xpath_variable_set(const xpath_variable_set& rhs) + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + _data[i] = 0; + + _assign(rhs); + } + + PUGI__FN xpath_variable_set& xpath_variable_set::operator=(const xpath_variable_set& rhs) + { + if (this == &rhs) return *this; + + _assign(rhs); + + return *this; + } + +#ifdef PUGIXML_HAS_MOVE + PUGI__FN xpath_variable_set::xpath_variable_set(xpath_variable_set&& rhs) PUGIXML_NOEXCEPT + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + { + _data[i] = rhs._data[i]; + rhs._data[i] = 0; + } + } + + PUGI__FN xpath_variable_set& xpath_variable_set::operator=(xpath_variable_set&& rhs) PUGIXML_NOEXCEPT + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + { + _destroy(_data[i]); + + _data[i] = rhs._data[i]; + rhs._data[i] = 0; + } + + return *this; + } +#endif + + PUGI__FN void xpath_variable_set::_assign(const xpath_variable_set& rhs) + { + xpath_variable_set temp; + + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + if (rhs._data[i] && !_clone(rhs._data[i], &temp._data[i])) + return; + + _swap(temp); + } + + PUGI__FN void xpath_variable_set::_swap(xpath_variable_set& rhs) + { + for (size_t i = 0; i < sizeof(_data) / sizeof(_data[0]); ++i) + { + xpath_variable* chain = _data[i]; + + _data[i] = rhs._data[i]; + rhs._data[i] = chain; + } + } + + PUGI__FN xpath_variable* xpath_variable_set::_find(const char_t* name) const + { + const size_t hash_size = sizeof(_data) / sizeof(_data[0]); + size_t hash = impl::hash_string(name) % hash_size; + + // look for existing variable + for (xpath_variable* var = _data[hash]; var; var = var->_next) + if (impl::strequal(var->name(), name)) + return var; + + return 0; + } + + PUGI__FN bool xpath_variable_set::_clone(xpath_variable* var, xpath_variable** out_result) + { + xpath_variable* last = 0; + + while (var) + { + // allocate storage for new variable + xpath_variable* nvar = impl::new_xpath_variable(var->_type, var->name()); + if (!nvar) return false; + + // link the variable to the result immediately to handle failures gracefully + if (last) + last->_next = nvar; + else + *out_result = nvar; + + last = nvar; + + // copy the value; this can fail due to out-of-memory conditions + if (!impl::copy_xpath_variable(nvar, var)) return false; + + var = var->_next; + } + + return true; + } + + PUGI__FN void xpath_variable_set::_destroy(xpath_variable* var) + { + while (var) + { + xpath_variable* next = var->_next; + + impl::delete_xpath_variable(var->_type, var); + + var = next; + } + } + + PUGI__FN xpath_variable* xpath_variable_set::add(const char_t* name, xpath_value_type type) + { + const size_t hash_size = sizeof(_data) / sizeof(_data[0]); + size_t hash = impl::hash_string(name) % hash_size; + + // look for existing variable + for (xpath_variable* var = _data[hash]; var; var = var->_next) + if (impl::strequal(var->name(), name)) + return var->type() == type ? var : 0; + + // add new variable + xpath_variable* result = impl::new_xpath_variable(type, name); + + if (result) + { + result->_next = _data[hash]; + + _data[hash] = result; + } + + return result; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, bool value) + { + xpath_variable* var = add(name, xpath_type_boolean); + return var ? var->set(value) : false; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, double value) + { + xpath_variable* var = add(name, xpath_type_number); + return var ? var->set(value) : false; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, const char_t* value) + { + xpath_variable* var = add(name, xpath_type_string); + return var ? var->set(value) : false; + } + + PUGI__FN bool xpath_variable_set::set(const char_t* name, const xpath_node_set& value) + { + xpath_variable* var = add(name, xpath_type_node_set); + return var ? var->set(value) : false; + } + + PUGI__FN xpath_variable* xpath_variable_set::get(const char_t* name) + { + return _find(name); + } + + PUGI__FN const xpath_variable* xpath_variable_set::get(const char_t* name) const + { + return _find(name); + } + + PUGI__FN xpath_query::xpath_query(const char_t* query, xpath_variable_set* variables): _impl(0) + { + impl::xpath_query_impl* qimpl = impl::xpath_query_impl::create(); + + if (!qimpl) + { + #ifdef PUGIXML_NO_EXCEPTIONS + _result.error = "Out of memory"; + #else + throw std::bad_alloc(); + #endif + } + else + { + using impl::auto_deleter; // MSVC7 workaround + auto_deleter impl(qimpl, impl::xpath_query_impl::destroy); + + qimpl->root = impl::xpath_parser::parse(query, variables, &qimpl->alloc, &_result); + + if (qimpl->root) + { + qimpl->root->optimize(&qimpl->alloc); + + _impl = impl.release(); + _result.error = 0; + } + else + { + #ifdef PUGIXML_NO_EXCEPTIONS + if (qimpl->oom) _result.error = "Out of memory"; + #else + if (qimpl->oom) throw std::bad_alloc(); + throw xpath_exception(_result); + #endif + } + } + } + + PUGI__FN xpath_query::xpath_query(): _impl(0) + { + } + + PUGI__FN xpath_query::~xpath_query() + { + if (_impl) + impl::xpath_query_impl::destroy(static_cast(_impl)); + } + +#ifdef PUGIXML_HAS_MOVE + PUGI__FN xpath_query::xpath_query(xpath_query&& rhs) PUGIXML_NOEXCEPT + { + _impl = rhs._impl; + _result = rhs._result; + rhs._impl = 0; + rhs._result = xpath_parse_result(); + } + + PUGI__FN xpath_query& xpath_query::operator=(xpath_query&& rhs) PUGIXML_NOEXCEPT + { + if (this == &rhs) return *this; + + if (_impl) + impl::xpath_query_impl::destroy(static_cast(_impl)); + + _impl = rhs._impl; + _result = rhs._result; + rhs._impl = 0; + rhs._result = xpath_parse_result(); + + return *this; + } +#endif + + PUGI__FN xpath_value_type xpath_query::return_type() const + { + if (!_impl) return xpath_type_none; + + return static_cast(_impl)->root->rettype(); + } + + PUGI__FN bool xpath_query::evaluate_boolean(const xpath_node& n) const + { + if (!_impl) return false; + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + bool r = static_cast(_impl)->root->eval_boolean(c, sd.stack); + + if (sd.oom) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return false; + #else + throw std::bad_alloc(); + #endif + } + + return r; + } + + PUGI__FN double xpath_query::evaluate_number(const xpath_node& n) const + { + if (!_impl) return impl::gen_nan(); + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + double r = static_cast(_impl)->root->eval_number(c, sd.stack); + + if (sd.oom) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return impl::gen_nan(); + #else + throw std::bad_alloc(); + #endif + } + + return r; + } + +#ifndef PUGIXML_NO_STL + PUGI__FN string_t xpath_query::evaluate_string(const xpath_node& n) const + { + if (!_impl) return string_t(); + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + impl::xpath_string r = static_cast(_impl)->root->eval_string(c, sd.stack); + + if (sd.oom) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return string_t(); + #else + throw std::bad_alloc(); + #endif + } + + return string_t(r.c_str(), r.length()); + } +#endif + + PUGI__FN size_t xpath_query::evaluate_string(char_t* buffer, size_t capacity, const xpath_node& n) const + { + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + impl::xpath_string r = _impl ? static_cast(_impl)->root->eval_string(c, sd.stack) : impl::xpath_string(); + + if (sd.oom) + { + #ifdef PUGIXML_NO_EXCEPTIONS + r = impl::xpath_string(); + #else + throw std::bad_alloc(); + #endif + } + + size_t full_size = r.length() + 1; + + if (capacity > 0) + { + size_t size = (full_size < capacity) ? full_size : capacity; + assert(size > 0); + + memcpy(buffer, r.c_str(), (size - 1) * sizeof(char_t)); + buffer[size - 1] = 0; + } + + return full_size; + } + + PUGI__FN xpath_node_set xpath_query::evaluate_node_set(const xpath_node& n) const + { + impl::xpath_ast_node* root = impl::evaluate_node_set_prepare(static_cast(_impl)); + if (!root) return xpath_node_set(); + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + impl::xpath_node_set_raw r = root->eval_node_set(c, sd.stack, impl::nodeset_eval_all); + + if (sd.oom) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return xpath_node_set(); + #else + throw std::bad_alloc(); + #endif + } + + return xpath_node_set(r.begin(), r.end(), r.type()); + } + + PUGI__FN xpath_node xpath_query::evaluate_node(const xpath_node& n) const + { + impl::xpath_ast_node* root = impl::evaluate_node_set_prepare(static_cast(_impl)); + if (!root) return xpath_node(); + + impl::xpath_context c(n, 1, 1); + impl::xpath_stack_data sd; + + impl::xpath_node_set_raw r = root->eval_node_set(c, sd.stack, impl::nodeset_eval_first); + + if (sd.oom) + { + #ifdef PUGIXML_NO_EXCEPTIONS + return xpath_node(); + #else + throw std::bad_alloc(); + #endif + } + + return r.first(); + } + + PUGI__FN const xpath_parse_result& xpath_query::result() const + { + return _result; + } + + PUGI__FN static void unspecified_bool_xpath_query(xpath_query***) + { + } + + PUGI__FN xpath_query::operator xpath_query::unspecified_bool_type() const + { + return _impl ? unspecified_bool_xpath_query : 0; + } + + PUGI__FN bool xpath_query::operator!() const + { + return !_impl; + } + + PUGI__FN xpath_node xml_node::select_node(const char_t* query, xpath_variable_set* variables) const + { + xpath_query q(query, variables); + return q.evaluate_node(*this); + } + + PUGI__FN xpath_node xml_node::select_node(const xpath_query& query) const + { + return query.evaluate_node(*this); + } + + PUGI__FN xpath_node_set xml_node::select_nodes(const char_t* query, xpath_variable_set* variables) const + { + xpath_query q(query, variables); + return q.evaluate_node_set(*this); + } + + PUGI__FN xpath_node_set xml_node::select_nodes(const xpath_query& query) const + { + return query.evaluate_node_set(*this); + } + + PUGI__FN xpath_node xml_node::select_single_node(const char_t* query, xpath_variable_set* variables) const + { + xpath_query q(query, variables); + return q.evaluate_node(*this); + } + + PUGI__FN xpath_node xml_node::select_single_node(const xpath_query& query) const + { + return query.evaluate_node(*this); + } +} + +#endif + +#ifdef __BORLANDC__ +# pragma option pop +#endif + +// Intel C++ does not properly keep warning state for function templates, +// so popping warning state at the end of translation unit leads to warnings in the middle. +#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) +# pragma warning(pop) +#endif + +#if defined(_MSC_VER) && defined(__c2__) +# pragma clang diagnostic pop +#endif + +// Undefine all local macros (makes sure we're not leaking macros in header-only mode) +#undef PUGI__NO_INLINE +#undef PUGI__UNLIKELY +#undef PUGI__STATIC_ASSERT +#undef PUGI__DMC_VOLATILE +#undef PUGI__UNSIGNED_OVERFLOW +#undef PUGI__MSVC_CRT_VERSION +#undef PUGI__SNPRINTF +#undef PUGI__NS_BEGIN +#undef PUGI__NS_END +#undef PUGI__FN +#undef PUGI__FN_NO_INLINE +#undef PUGI__GETHEADER_IMPL +#undef PUGI__GETPAGE_IMPL +#undef PUGI__GETPAGE +#undef PUGI__NODETYPE +#undef PUGI__IS_CHARTYPE_IMPL +#undef PUGI__IS_CHARTYPE +#undef PUGI__IS_CHARTYPEX +#undef PUGI__ENDSWITH +#undef PUGI__SKIPWS +#undef PUGI__OPTSET +#undef PUGI__PUSHNODE +#undef PUGI__POPNODE +#undef PUGI__SCANFOR +#undef PUGI__SCANWHILE +#undef PUGI__SCANWHILE_UNROLL +#undef PUGI__ENDSEG +#undef PUGI__THROW_ERROR +#undef PUGI__CHECK_ERROR + +#endif + +/** + * Copyright (c) 2006-2020 Arseny Kapoulkine + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ diff --git a/src/Onvif/pugixml.hpp b/src/Onvif/pugixml.hpp new file mode 100644 index 00000000..71dbf91b --- /dev/null +++ b/src/Onvif/pugixml.hpp @@ -0,0 +1,1501 @@ +/** + * pugixml parser - version 1.11 + * -------------------------------------------------------- + * Copyright (C) 2006-2020, by Arseny Kapoulkine (arseny.kapoulkine@gmail.com) + * Report bugs and download new versions at https://pugixml.org/ + * + * This library is distributed under the MIT License. See notice at the end + * of this file. + * + * This work is based on the pugxml parser, which is: + * Copyright (C) 2003, by Kristen Wegner (kristen@tima.net) + */ + +#ifndef PUGIXML_VERSION +// Define version macro; evaluates to major * 1000 + minor * 10 + patch so that it's safe to use in less-than comparisons +// Note: pugixml used major * 100 + minor * 10 + patch format up until 1.9 (which had version identifier 190); starting from pugixml 1.10, the minor version number is two digits +# define PUGIXML_VERSION 1110 +#endif + +// Include user configuration file (this can define various configuration macros) +#include "pugiconfig.hpp" + +#ifndef HEADER_PUGIXML_HPP +#define HEADER_PUGIXML_HPP + +// Include stddef.h for size_t and ptrdiff_t +#include + +// Include exception header for XPath +#if !defined(PUGIXML_NO_XPATH) && !defined(PUGIXML_NO_EXCEPTIONS) +# include +#endif + +// Include STL headers +#ifndef PUGIXML_NO_STL +# include +# include +# include +#endif + +// Macro for deprecated features +#ifndef PUGIXML_DEPRECATED +# if defined(__GNUC__) +# define PUGIXML_DEPRECATED __attribute__((deprecated)) +# elif defined(_MSC_VER) && _MSC_VER >= 1300 +# define PUGIXML_DEPRECATED __declspec(deprecated) +# else +# define PUGIXML_DEPRECATED +# endif +#endif + +// If no API is defined, assume default +#ifndef PUGIXML_API +# define PUGIXML_API +#endif + +// If no API for classes is defined, assume default +#ifndef PUGIXML_CLASS +# define PUGIXML_CLASS PUGIXML_API +#endif + +// If no API for functions is defined, assume default +#ifndef PUGIXML_FUNCTION +# define PUGIXML_FUNCTION PUGIXML_API +#endif + +// If the platform is known to have long long support, enable long long functions +#ifndef PUGIXML_HAS_LONG_LONG +# if __cplusplus >= 201103 +# define PUGIXML_HAS_LONG_LONG +# elif defined(_MSC_VER) && _MSC_VER >= 1400 +# define PUGIXML_HAS_LONG_LONG +# endif +#endif + +// If the platform is known to have move semantics support, compile move ctor/operator implementation +#ifndef PUGIXML_HAS_MOVE +# if __cplusplus >= 201103 +# define PUGIXML_HAS_MOVE +# elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define PUGIXML_HAS_MOVE +# endif +#endif + +// If C++ is 2011 or higher, add 'noexcept' specifiers +#ifndef PUGIXML_NOEXCEPT +# if __cplusplus >= 201103 +# define PUGIXML_NOEXCEPT noexcept +# elif defined(_MSC_VER) && _MSC_VER >= 1900 +# define PUGIXML_NOEXCEPT noexcept +# else +# define PUGIXML_NOEXCEPT +# endif +#endif + +// Some functions can not be noexcept in compact mode +#ifdef PUGIXML_COMPACT +# define PUGIXML_NOEXCEPT_IF_NOT_COMPACT +#else +# define PUGIXML_NOEXCEPT_IF_NOT_COMPACT PUGIXML_NOEXCEPT +#endif + +// If C++ is 2011 or higher, add 'override' qualifiers +#ifndef PUGIXML_OVERRIDE +# if __cplusplus >= 201103 +# define PUGIXML_OVERRIDE override +# elif defined(_MSC_VER) && _MSC_VER >= 1700 +# define PUGIXML_OVERRIDE override +# else +# define PUGIXML_OVERRIDE +# endif +#endif + +// If C++ is 2011 or higher, use 'nullptr' +#ifndef PUGIXML_NULL +# if __cplusplus >= 201103 +# define PUGIXML_NULL nullptr +# else +# define PUGIXML_NULL 0 +# endif +#endif + +// Character interface macros +#ifdef PUGIXML_WCHAR_MODE +# define PUGIXML_TEXT(t) L ## t +# define PUGIXML_CHAR wchar_t +#else +# define PUGIXML_TEXT(t) t +# define PUGIXML_CHAR char +#endif + +namespace pugi +{ + // Character type used for all internal storage and operations; depends on PUGIXML_WCHAR_MODE + typedef PUGIXML_CHAR char_t; + +#ifndef PUGIXML_NO_STL + // String type used for operations that work with STL string; depends on PUGIXML_WCHAR_MODE + typedef std::basic_string, std::allocator > string_t; +#endif +} + +// The PugiXML namespace +namespace pugi +{ + // Tree node types + enum xml_node_type + { + node_null, // Empty (null) node handle + node_document, // A document tree's absolute root + node_element, // Element tag, i.e. '' + node_pcdata, // Plain character data, i.e. 'text' + node_cdata, // Character data, i.e. '' + node_comment, // Comment tag, i.e. '' + node_pi, // Processing instruction, i.e. '' + node_declaration, // Document declaration, i.e. '' + node_doctype // Document type declaration, i.e. '' + }; + + // Parsing options + + // Minimal parsing mode (equivalent to turning all other flags off). + // Only elements and PCDATA sections are added to the DOM tree, no text conversions are performed. + const unsigned int parse_minimal = 0x0000; + + // This flag determines if processing instructions (node_pi) are added to the DOM tree. This flag is off by default. + const unsigned int parse_pi = 0x0001; + + // This flag determines if comments (node_comment) are added to the DOM tree. This flag is off by default. + const unsigned int parse_comments = 0x0002; + + // This flag determines if CDATA sections (node_cdata) are added to the DOM tree. This flag is on by default. + const unsigned int parse_cdata = 0x0004; + + // This flag determines if plain character data (node_pcdata) that consist only of whitespace are added to the DOM tree. + // This flag is off by default; turning it on usually results in slower parsing and more memory consumption. + const unsigned int parse_ws_pcdata = 0x0008; + + // This flag determines if character and entity references are expanded during parsing. This flag is on by default. + const unsigned int parse_escapes = 0x0010; + + // This flag determines if EOL characters are normalized (converted to #xA) during parsing. This flag is on by default. + const unsigned int parse_eol = 0x0020; + + // This flag determines if attribute values are normalized using CDATA normalization rules during parsing. This flag is on by default. + const unsigned int parse_wconv_attribute = 0x0040; + + // This flag determines if attribute values are normalized using NMTOKENS normalization rules during parsing. This flag is off by default. + const unsigned int parse_wnorm_attribute = 0x0080; + + // This flag determines if document declaration (node_declaration) is added to the DOM tree. This flag is off by default. + const unsigned int parse_declaration = 0x0100; + + // This flag determines if document type declaration (node_doctype) is added to the DOM tree. This flag is off by default. + const unsigned int parse_doctype = 0x0200; + + // This flag determines if plain character data (node_pcdata) that is the only child of the parent node and that consists only + // of whitespace is added to the DOM tree. + // This flag is off by default; turning it on may result in slower parsing and more memory consumption. + const unsigned int parse_ws_pcdata_single = 0x0400; + + // This flag determines if leading and trailing whitespace is to be removed from plain character data. This flag is off by default. + const unsigned int parse_trim_pcdata = 0x0800; + + // This flag determines if plain character data that does not have a parent node is added to the DOM tree, and if an empty document + // is a valid document. This flag is off by default. + const unsigned int parse_fragment = 0x1000; + + // This flag determines if plain character data is be stored in the parent element's value. This significantly changes the structure of + // the document; this flag is only recommended for parsing documents with many PCDATA nodes in memory-constrained environments. + // This flag is off by default. + const unsigned int parse_embed_pcdata = 0x2000; + + // The default parsing mode. + // Elements, PCDATA and CDATA sections are added to the DOM tree, character/reference entities are expanded, + // End-of-Line characters are normalized, attribute values are normalized using CDATA normalization rules. + const unsigned int parse_default = parse_cdata | parse_escapes | parse_wconv_attribute | parse_eol; + + // The full parsing mode. + // Nodes of all types are added to the DOM tree, character/reference entities are expanded, + // End-of-Line characters are normalized, attribute values are normalized using CDATA normalization rules. + const unsigned int parse_full = parse_default | parse_pi | parse_comments | parse_declaration | parse_doctype; + + // These flags determine the encoding of input data for XML document + enum xml_encoding + { + encoding_auto, // Auto-detect input encoding using BOM or < / class xml_object_range + { + public: + typedef It const_iterator; + typedef It iterator; + + xml_object_range(It b, It e): _begin(b), _end(e) + { + } + + It begin() const { return _begin; } + It end() const { return _end; } + + bool empty() const { return _begin == _end; } + + private: + It _begin, _end; + }; + + // Writer interface for node printing (see xml_node::print) + class PUGIXML_CLASS xml_writer + { + public: + virtual ~xml_writer() {} + + // Write memory chunk into stream/file/whatever + virtual void write(const void* data, size_t size) = 0; + }; + + // xml_writer implementation for FILE* + class PUGIXML_CLASS xml_writer_file: public xml_writer + { + public: + // Construct writer from a FILE* object; void* is used to avoid header dependencies on stdio + xml_writer_file(void* file); + + virtual void write(const void* data, size_t size) PUGIXML_OVERRIDE; + + private: + void* file; + }; + + #ifndef PUGIXML_NO_STL + // xml_writer implementation for streams + class PUGIXML_CLASS xml_writer_stream: public xml_writer + { + public: + // Construct writer from an output stream object + xml_writer_stream(std::basic_ostream >& stream); + xml_writer_stream(std::basic_ostream >& stream); + + virtual void write(const void* data, size_t size) PUGIXML_OVERRIDE; + + private: + std::basic_ostream >* narrow_stream; + std::basic_ostream >* wide_stream; + }; + #endif + + // A light-weight handle for manipulating attributes in DOM tree + class PUGIXML_CLASS xml_attribute + { + friend class xml_attribute_iterator; + friend class xml_node; + + private: + xml_attribute_struct* _attr; + + typedef void (*unspecified_bool_type)(xml_attribute***); + + public: + // Default constructor. Constructs an empty attribute. + xml_attribute(); + + // Constructs attribute from internal pointer + explicit xml_attribute(xml_attribute_struct* attr); + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Comparison operators (compares wrapped attribute pointers) + bool operator==(const xml_attribute& r) const; + bool operator!=(const xml_attribute& r) const; + bool operator<(const xml_attribute& r) const; + bool operator>(const xml_attribute& r) const; + bool operator<=(const xml_attribute& r) const; + bool operator>=(const xml_attribute& r) const; + + // Check if attribute is empty + bool empty() const; + + // Get attribute name/value, or "" if attribute is empty + const char_t* name() const; + const char_t* value() const; + + // Get attribute value, or the default value if attribute is empty + const char_t* as_string(const char_t* def = PUGIXML_TEXT("")) const; + + // Get attribute value as a number, or the default value if conversion did not succeed or attribute is empty + int as_int(int def = 0) const; + unsigned int as_uint(unsigned int def = 0) const; + double as_double(double def = 0) const; + float as_float(float def = 0) const; + + #ifdef PUGIXML_HAS_LONG_LONG + long long as_llong(long long def = 0) const; + unsigned long long as_ullong(unsigned long long def = 0) const; + #endif + + // Get attribute value as bool (returns true if first character is in '1tTyY' set), or the default value if attribute is empty + bool as_bool(bool def = false) const; + + // Set attribute name/value (returns false if attribute is empty or there is not enough memory) + bool set_name(const char_t* rhs); + bool set_value(const char_t* rhs); + + // Set attribute value with type conversion (numbers are converted to strings, boolean is converted to "true"/"false") + bool set_value(int rhs); + bool set_value(unsigned int rhs); + bool set_value(long rhs); + bool set_value(unsigned long rhs); + bool set_value(double rhs); + bool set_value(double rhs, int precision); + bool set_value(float rhs); + bool set_value(float rhs, int precision); + bool set_value(bool rhs); + + #ifdef PUGIXML_HAS_LONG_LONG + bool set_value(long long rhs); + bool set_value(unsigned long long rhs); + #endif + + // Set attribute value (equivalent to set_value without error checking) + xml_attribute& operator=(const char_t* rhs); + xml_attribute& operator=(int rhs); + xml_attribute& operator=(unsigned int rhs); + xml_attribute& operator=(long rhs); + xml_attribute& operator=(unsigned long rhs); + xml_attribute& operator=(double rhs); + xml_attribute& operator=(float rhs); + xml_attribute& operator=(bool rhs); + + #ifdef PUGIXML_HAS_LONG_LONG + xml_attribute& operator=(long long rhs); + xml_attribute& operator=(unsigned long long rhs); + #endif + + // Get next/previous attribute in the attribute list of the parent node + xml_attribute next_attribute() const; + xml_attribute previous_attribute() const; + + // Get hash value (unique for handles to the same object) + size_t hash_value() const; + + // Get internal pointer + xml_attribute_struct* internal_object() const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xml_attribute& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xml_attribute& lhs, bool rhs); +#endif + + // A light-weight handle for manipulating nodes in DOM tree + class PUGIXML_CLASS xml_node + { + friend class xml_attribute_iterator; + friend class xml_node_iterator; + friend class xml_named_node_iterator; + + protected: + xml_node_struct* _root; + + typedef void (*unspecified_bool_type)(xml_node***); + + public: + // Default constructor. Constructs an empty node. + xml_node(); + + // Constructs node from internal pointer + explicit xml_node(xml_node_struct* p); + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Comparison operators (compares wrapped node pointers) + bool operator==(const xml_node& r) const; + bool operator!=(const xml_node& r) const; + bool operator<(const xml_node& r) const; + bool operator>(const xml_node& r) const; + bool operator<=(const xml_node& r) const; + bool operator>=(const xml_node& r) const; + + // Check if node is empty. + bool empty() const; + + // Get node type + xml_node_type type() const; + + // Get node name, or "" if node is empty or it has no name + const char_t* name() const; + + // Get node value, or "" if node is empty or it has no value + // Note: For text node.value() does not return "text"! Use child_value() or text() methods to access text inside nodes. + const char_t* value() const; + + // Get attribute list + xml_attribute first_attribute() const; + xml_attribute last_attribute() const; + + // Get children list + xml_node first_child() const; + xml_node last_child() const; + + // Get next/previous sibling in the children list of the parent node + xml_node next_sibling() const; + xml_node previous_sibling() const; + + // Get parent node + xml_node parent() const; + + // Get root of DOM tree this node belongs to + xml_node root() const; + + // Get text object for the current node + xml_text text() const; + + // Get child, attribute or next/previous sibling with the specified name + xml_node child(const char_t* name) const; + xml_attribute attribute(const char_t* name) const; + xml_node next_sibling(const char_t* name) const; + xml_node previous_sibling(const char_t* name) const; + + // Get attribute, starting the search from a hint (and updating hint so that searching for a sequence of attributes is fast) + xml_attribute attribute(const char_t* name, xml_attribute& hint) const; + + // Get child value of current node; that is, value of the first child node of type PCDATA/CDATA + const char_t* child_value() const; + + // Get child value of child with specified name. Equivalent to child(name).child_value(). + const char_t* child_value(const char_t* name) const; + + // Set node name/value (returns false if node is empty, there is not enough memory, or node can not have name/value) + bool set_name(const char_t* rhs); + bool set_value(const char_t* rhs); + + // Add attribute with specified name. Returns added attribute, or empty attribute on errors. + xml_attribute append_attribute(const char_t* name); + xml_attribute prepend_attribute(const char_t* name); + xml_attribute insert_attribute_after(const char_t* name, const xml_attribute& attr); + xml_attribute insert_attribute_before(const char_t* name, const xml_attribute& attr); + + // Add a copy of the specified attribute. Returns added attribute, or empty attribute on errors. + xml_attribute append_copy(const xml_attribute& proto); + xml_attribute prepend_copy(const xml_attribute& proto); + xml_attribute insert_copy_after(const xml_attribute& proto, const xml_attribute& attr); + xml_attribute insert_copy_before(const xml_attribute& proto, const xml_attribute& attr); + + // Add child node with specified type. Returns added node, or empty node on errors. + xml_node append_child(xml_node_type type = node_element); + xml_node prepend_child(xml_node_type type = node_element); + xml_node insert_child_after(xml_node_type type, const xml_node& node); + xml_node insert_child_before(xml_node_type type, const xml_node& node); + + // Add child element with specified name. Returns added node, or empty node on errors. + xml_node append_child(const char_t* name); + xml_node prepend_child(const char_t* name); + xml_node insert_child_after(const char_t* name, const xml_node& node); + xml_node insert_child_before(const char_t* name, const xml_node& node); + + // Add a copy of the specified node as a child. Returns added node, or empty node on errors. + xml_node append_copy(const xml_node& proto); + xml_node prepend_copy(const xml_node& proto); + xml_node insert_copy_after(const xml_node& proto, const xml_node& node); + xml_node insert_copy_before(const xml_node& proto, const xml_node& node); + + // Move the specified node to become a child of this node. Returns moved node, or empty node on errors. + xml_node append_move(const xml_node& moved); + xml_node prepend_move(const xml_node& moved); + xml_node insert_move_after(const xml_node& moved, const xml_node& node); + xml_node insert_move_before(const xml_node& moved, const xml_node& node); + + // Remove specified attribute + bool remove_attribute(const xml_attribute& a); + bool remove_attribute(const char_t* name); + + // Remove all attributes + bool remove_attributes(); + + // Remove specified child + bool remove_child(const xml_node& n); + bool remove_child(const char_t* name); + + // Remove all children + bool remove_children(); + + // Parses buffer as an XML document fragment and appends all nodes as children of the current node. + // Copies/converts the buffer, so it may be deleted or changed after the function returns. + // Note: append_buffer allocates memory that has the lifetime of the owning document; removing the appended nodes does not immediately reclaim that memory. + xml_parse_result append_buffer(const void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Find attribute using predicate. Returns first attribute for which predicate returned true. + template xml_attribute find_attribute(Predicate pred) const + { + if (!_root) return xml_attribute(); + + for (xml_attribute attrib = first_attribute(); attrib; attrib = attrib.next_attribute()) + if (pred(attrib)) + return attrib; + + return xml_attribute(); + } + + // Find child node using predicate. Returns first child for which predicate returned true. + template xml_node find_child(Predicate pred) const + { + if (!_root) return xml_node(); + + for (xml_node node = first_child(); node; node = node.next_sibling()) + if (pred(node)) + return node; + + return xml_node(); + } + + // Find node from subtree using predicate. Returns first node from subtree (depth-first), for which predicate returned true. + template xml_node find_node(Predicate pred) const + { + if (!_root) return xml_node(); + + xml_node cur = first_child(); + + while (cur._root && cur._root != _root) + { + if (pred(cur)) return cur; + + if (cur.first_child()) cur = cur.first_child(); + else if (cur.next_sibling()) cur = cur.next_sibling(); + else + { + while (!cur.next_sibling() && cur._root != _root) cur = cur.parent(); + + if (cur._root != _root) cur = cur.next_sibling(); + } + } + + return xml_node(); + } + + // Find child node by attribute name/value + xml_node find_child_by_attribute(const char_t* name, const char_t* attr_name, const char_t* attr_value) const; + xml_node find_child_by_attribute(const char_t* attr_name, const char_t* attr_value) const; + + #ifndef PUGIXML_NO_STL + // Get the absolute node path from root as a text string. + string_t path(char_t delimiter = '/') const; + #endif + + // Search for a node by path consisting of node names and . or .. elements. + xml_node first_element_by_path(const char_t* path, char_t delimiter = '/') const; + + // Recursively traverse subtree with xml_tree_walker + bool traverse(xml_tree_walker& walker); + + #ifndef PUGIXML_NO_XPATH + // Select single node by evaluating XPath query. Returns first node from the resulting node set. + xpath_node select_node(const char_t* query, xpath_variable_set* variables = PUGIXML_NULL) const; + xpath_node select_node(const xpath_query& query) const; + + // Select node set by evaluating XPath query + xpath_node_set select_nodes(const char_t* query, xpath_variable_set* variables = PUGIXML_NULL) const; + xpath_node_set select_nodes(const xpath_query& query) const; + + // (deprecated: use select_node instead) Select single node by evaluating XPath query. + PUGIXML_DEPRECATED xpath_node select_single_node(const char_t* query, xpath_variable_set* variables = PUGIXML_NULL) const; + PUGIXML_DEPRECATED xpath_node select_single_node(const xpath_query& query) const; + + #endif + + // Print subtree using a writer object + void print(xml_writer& writer, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto, unsigned int depth = 0) const; + + #ifndef PUGIXML_NO_STL + // Print subtree to stream + void print(std::basic_ostream >& os, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto, unsigned int depth = 0) const; + void print(std::basic_ostream >& os, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, unsigned int depth = 0) const; + #endif + + // Child nodes iterators + typedef xml_node_iterator iterator; + + iterator begin() const; + iterator end() const; + + // Attribute iterators + typedef xml_attribute_iterator attribute_iterator; + + attribute_iterator attributes_begin() const; + attribute_iterator attributes_end() const; + + // Range-based for support + xml_object_range children() const; + xml_object_range children(const char_t* name) const; + xml_object_range attributes() const; + + // Get node offset in parsed file/string (in char_t units) for debugging purposes + ptrdiff_t offset_debug() const; + + // Get hash value (unique for handles to the same object) + size_t hash_value() const; + + // Get internal pointer + xml_node_struct* internal_object() const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xml_node& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xml_node& lhs, bool rhs); +#endif + + // A helper for working with text inside PCDATA nodes + class PUGIXML_CLASS xml_text + { + friend class xml_node; + + xml_node_struct* _root; + + typedef void (*unspecified_bool_type)(xml_text***); + + explicit xml_text(xml_node_struct* root); + + xml_node_struct* _data_new(); + xml_node_struct* _data() const; + + public: + // Default constructor. Constructs an empty object. + xml_text(); + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Check if text object is empty + bool empty() const; + + // Get text, or "" if object is empty + const char_t* get() const; + + // Get text, or the default value if object is empty + const char_t* as_string(const char_t* def = PUGIXML_TEXT("")) const; + + // Get text as a number, or the default value if conversion did not succeed or object is empty + int as_int(int def = 0) const; + unsigned int as_uint(unsigned int def = 0) const; + double as_double(double def = 0) const; + float as_float(float def = 0) const; + + #ifdef PUGIXML_HAS_LONG_LONG + long long as_llong(long long def = 0) const; + unsigned long long as_ullong(unsigned long long def = 0) const; + #endif + + // Get text as bool (returns true if first character is in '1tTyY' set), or the default value if object is empty + bool as_bool(bool def = false) const; + + // Set text (returns false if object is empty or there is not enough memory) + bool set(const char_t* rhs); + + // Set text with type conversion (numbers are converted to strings, boolean is converted to "true"/"false") + bool set(int rhs); + bool set(unsigned int rhs); + bool set(long rhs); + bool set(unsigned long rhs); + bool set(double rhs); + bool set(double rhs, int precision); + bool set(float rhs); + bool set(float rhs, int precision); + bool set(bool rhs); + + #ifdef PUGIXML_HAS_LONG_LONG + bool set(long long rhs); + bool set(unsigned long long rhs); + #endif + + // Set text (equivalent to set without error checking) + xml_text& operator=(const char_t* rhs); + xml_text& operator=(int rhs); + xml_text& operator=(unsigned int rhs); + xml_text& operator=(long rhs); + xml_text& operator=(unsigned long rhs); + xml_text& operator=(double rhs); + xml_text& operator=(float rhs); + xml_text& operator=(bool rhs); + + #ifdef PUGIXML_HAS_LONG_LONG + xml_text& operator=(long long rhs); + xml_text& operator=(unsigned long long rhs); + #endif + + // Get the data node (node_pcdata or node_cdata) for this object + xml_node data() const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xml_text& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xml_text& lhs, bool rhs); +#endif + + // Child node iterator (a bidirectional iterator over a collection of xml_node) + class PUGIXML_CLASS xml_node_iterator + { + friend class xml_node; + + private: + mutable xml_node _wrap; + xml_node _parent; + + xml_node_iterator(xml_node_struct* ref, xml_node_struct* parent); + + public: + // Iterator traits + typedef ptrdiff_t difference_type; + typedef xml_node value_type; + typedef xml_node* pointer; + typedef xml_node& reference; + + #ifndef PUGIXML_NO_STL + typedef std::bidirectional_iterator_tag iterator_category; + #endif + + // Default constructor + xml_node_iterator(); + + // Construct an iterator which points to the specified node + xml_node_iterator(const xml_node& node); + + // Iterator operators + bool operator==(const xml_node_iterator& rhs) const; + bool operator!=(const xml_node_iterator& rhs) const; + + xml_node& operator*() const; + xml_node* operator->() const; + + xml_node_iterator& operator++(); + xml_node_iterator operator++(int); + + xml_node_iterator& operator--(); + xml_node_iterator operator--(int); + }; + + // Attribute iterator (a bidirectional iterator over a collection of xml_attribute) + class PUGIXML_CLASS xml_attribute_iterator + { + friend class xml_node; + + private: + mutable xml_attribute _wrap; + xml_node _parent; + + xml_attribute_iterator(xml_attribute_struct* ref, xml_node_struct* parent); + + public: + // Iterator traits + typedef ptrdiff_t difference_type; + typedef xml_attribute value_type; + typedef xml_attribute* pointer; + typedef xml_attribute& reference; + + #ifndef PUGIXML_NO_STL + typedef std::bidirectional_iterator_tag iterator_category; + #endif + + // Default constructor + xml_attribute_iterator(); + + // Construct an iterator which points to the specified attribute + xml_attribute_iterator(const xml_attribute& attr, const xml_node& parent); + + // Iterator operators + bool operator==(const xml_attribute_iterator& rhs) const; + bool operator!=(const xml_attribute_iterator& rhs) const; + + xml_attribute& operator*() const; + xml_attribute* operator->() const; + + xml_attribute_iterator& operator++(); + xml_attribute_iterator operator++(int); + + xml_attribute_iterator& operator--(); + xml_attribute_iterator operator--(int); + }; + + // Named node range helper + class PUGIXML_CLASS xml_named_node_iterator + { + friend class xml_node; + + public: + // Iterator traits + typedef ptrdiff_t difference_type; + typedef xml_node value_type; + typedef xml_node* pointer; + typedef xml_node& reference; + + #ifndef PUGIXML_NO_STL + typedef std::bidirectional_iterator_tag iterator_category; + #endif + + // Default constructor + xml_named_node_iterator(); + + // Construct an iterator which points to the specified node + xml_named_node_iterator(const xml_node& node, const char_t* name); + + // Iterator operators + bool operator==(const xml_named_node_iterator& rhs) const; + bool operator!=(const xml_named_node_iterator& rhs) const; + + xml_node& operator*() const; + xml_node* operator->() const; + + xml_named_node_iterator& operator++(); + xml_named_node_iterator operator++(int); + + xml_named_node_iterator& operator--(); + xml_named_node_iterator operator--(int); + + private: + mutable xml_node _wrap; + xml_node _parent; + const char_t* _name; + + xml_named_node_iterator(xml_node_struct* ref, xml_node_struct* parent, const char_t* name); + }; + + // Abstract tree walker class (see xml_node::traverse) + class PUGIXML_CLASS xml_tree_walker + { + friend class xml_node; + + private: + int _depth; + + protected: + // Get current traversal depth + int depth() const; + + public: + xml_tree_walker(); + virtual ~xml_tree_walker(); + + // Callback that is called when traversal begins + virtual bool begin(xml_node& node); + + // Callback that is called for each node traversed + virtual bool for_each(xml_node& node) = 0; + + // Callback that is called when traversal ends + virtual bool end(xml_node& node); + }; + + // Parsing status, returned as part of xml_parse_result object + enum xml_parse_status + { + status_ok = 0, // No error + + status_file_not_found, // File was not found during load_file() + status_io_error, // Error reading from file/stream + status_out_of_memory, // Could not allocate memory + status_internal_error, // Internal error occurred + + status_unrecognized_tag, // Parser could not determine tag type + + status_bad_pi, // Parsing error occurred while parsing document declaration/processing instruction + status_bad_comment, // Parsing error occurred while parsing comment + status_bad_cdata, // Parsing error occurred while parsing CDATA section + status_bad_doctype, // Parsing error occurred while parsing document type declaration + status_bad_pcdata, // Parsing error occurred while parsing PCDATA section + status_bad_start_element, // Parsing error occurred while parsing start element tag + status_bad_attribute, // Parsing error occurred while parsing element attribute + status_bad_end_element, // Parsing error occurred while parsing end element tag + status_end_element_mismatch,// There was a mismatch of start-end tags (closing tag had incorrect name, some tag was not closed or there was an excessive closing tag) + + status_append_invalid_root, // Unable to append nodes since root type is not node_element or node_document (exclusive to xml_node::append_buffer) + + status_no_document_element // Parsing resulted in a document without element nodes + }; + + // Parsing result + struct PUGIXML_CLASS xml_parse_result + { + // Parsing status (see xml_parse_status) + xml_parse_status status; + + // Last parsed offset (in char_t units from start of input data) + ptrdiff_t offset; + + // Source document encoding + xml_encoding encoding; + + // Default constructor, initializes object to failed state + xml_parse_result(); + + // Cast to bool operator + operator bool() const; + + // Get error description + const char* description() const; + }; + + // Document class (DOM tree root) + class PUGIXML_CLASS xml_document: public xml_node + { + private: + char_t* _buffer; + + char _memory[192]; + + // Non-copyable semantics + xml_document(const xml_document&); + xml_document& operator=(const xml_document&); + + void _create(); + void _destroy(); + void _move(xml_document& rhs) PUGIXML_NOEXCEPT_IF_NOT_COMPACT; + + public: + // Default constructor, makes empty document + xml_document(); + + // Destructor, invalidates all node/attribute handles to this document + ~xml_document(); + + #ifdef PUGIXML_HAS_MOVE + // Move semantics support + xml_document(xml_document&& rhs) PUGIXML_NOEXCEPT_IF_NOT_COMPACT; + xml_document& operator=(xml_document&& rhs) PUGIXML_NOEXCEPT_IF_NOT_COMPACT; + #endif + + // Removes all nodes, leaving the empty document + void reset(); + + // Removes all nodes, then copies the entire contents of the specified document + void reset(const xml_document& proto); + + #ifndef PUGIXML_NO_STL + // Load document from stream. + xml_parse_result load(std::basic_istream >& stream, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + xml_parse_result load(std::basic_istream >& stream, unsigned int options = parse_default); + #endif + + // (deprecated: use load_string instead) Load document from zero-terminated string. No encoding conversions are applied. + PUGIXML_DEPRECATED xml_parse_result load(const char_t* contents, unsigned int options = parse_default); + + // Load document from zero-terminated string. No encoding conversions are applied. + xml_parse_result load_string(const char_t* contents, unsigned int options = parse_default); + + // Load document from file + xml_parse_result load_file(const char* path, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + xml_parse_result load_file(const wchar_t* path, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Load document from buffer. Copies/converts the buffer, so it may be deleted or changed after the function returns. + xml_parse_result load_buffer(const void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Load document from buffer, using the buffer for in-place parsing (the buffer is modified and used for storage of document data). + // You should ensure that buffer data will persist throughout the document's lifetime, and free the buffer memory manually once document is destroyed. + xml_parse_result load_buffer_inplace(void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Load document from buffer, using the buffer for in-place parsing (the buffer is modified and used for storage of document data). + // You should allocate the buffer with pugixml allocation function; document will free the buffer when it is no longer needed (you can't use it anymore). + xml_parse_result load_buffer_inplace_own(void* contents, size_t size, unsigned int options = parse_default, xml_encoding encoding = encoding_auto); + + // Save XML document to writer (semantics is slightly different from xml_node::print, see documentation for details). + void save(xml_writer& writer, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + + #ifndef PUGIXML_NO_STL + // Save XML document to stream (semantics is slightly different from xml_node::print, see documentation for details). + void save(std::basic_ostream >& stream, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + void save(std::basic_ostream >& stream, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default) const; + #endif + + // Save XML to file + bool save_file(const char* path, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + bool save_file(const wchar_t* path, const char_t* indent = PUGIXML_TEXT("\t"), unsigned int flags = format_default, xml_encoding encoding = encoding_auto) const; + + // Get document element + xml_node document_element() const; + }; + +#ifndef PUGIXML_NO_XPATH + // XPath query return type + enum xpath_value_type + { + xpath_type_none, // Unknown type (query failed to compile) + xpath_type_node_set, // Node set (xpath_node_set) + xpath_type_number, // Number + xpath_type_string, // String + xpath_type_boolean // Boolean + }; + + // XPath parsing result + struct PUGIXML_CLASS xpath_parse_result + { + // Error message (0 if no error) + const char* error; + + // Last parsed offset (in char_t units from string start) + ptrdiff_t offset; + + // Default constructor, initializes object to failed state + xpath_parse_result(); + + // Cast to bool operator + operator bool() const; + + // Get error description + const char* description() const; + }; + + // A single XPath variable + class PUGIXML_CLASS xpath_variable + { + friend class xpath_variable_set; + + protected: + xpath_value_type _type; + xpath_variable* _next; + + xpath_variable(xpath_value_type type); + + // Non-copyable semantics + xpath_variable(const xpath_variable&); + xpath_variable& operator=(const xpath_variable&); + + public: + // Get variable name + const char_t* name() const; + + // Get variable type + xpath_value_type type() const; + + // Get variable value; no type conversion is performed, default value (false, NaN, empty string, empty node set) is returned on type mismatch error + bool get_boolean() const; + double get_number() const; + const char_t* get_string() const; + const xpath_node_set& get_node_set() const; + + // Set variable value; no type conversion is performed, false is returned on type mismatch error + bool set(bool value); + bool set(double value); + bool set(const char_t* value); + bool set(const xpath_node_set& value); + }; + + // A set of XPath variables + class PUGIXML_CLASS xpath_variable_set + { + private: + xpath_variable* _data[64]; + + void _assign(const xpath_variable_set& rhs); + void _swap(xpath_variable_set& rhs); + + xpath_variable* _find(const char_t* name) const; + + static bool _clone(xpath_variable* var, xpath_variable** out_result); + static void _destroy(xpath_variable* var); + + public: + // Default constructor/destructor + xpath_variable_set(); + ~xpath_variable_set(); + + // Copy constructor/assignment operator + xpath_variable_set(const xpath_variable_set& rhs); + xpath_variable_set& operator=(const xpath_variable_set& rhs); + + #ifdef PUGIXML_HAS_MOVE + // Move semantics support + xpath_variable_set(xpath_variable_set&& rhs) PUGIXML_NOEXCEPT; + xpath_variable_set& operator=(xpath_variable_set&& rhs) PUGIXML_NOEXCEPT; + #endif + + // Add a new variable or get the existing one, if the types match + xpath_variable* add(const char_t* name, xpath_value_type type); + + // Set value of an existing variable; no type conversion is performed, false is returned if there is no such variable or if types mismatch + bool set(const char_t* name, bool value); + bool set(const char_t* name, double value); + bool set(const char_t* name, const char_t* value); + bool set(const char_t* name, const xpath_node_set& value); + + // Get existing variable by name + xpath_variable* get(const char_t* name); + const xpath_variable* get(const char_t* name) const; + }; + + // A compiled XPath query object + class PUGIXML_CLASS xpath_query + { + private: + void* _impl; + xpath_parse_result _result; + + typedef void (*unspecified_bool_type)(xpath_query***); + + // Non-copyable semantics + xpath_query(const xpath_query&); + xpath_query& operator=(const xpath_query&); + + public: + // Construct a compiled object from XPath expression. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws xpath_exception on compilation errors. + explicit xpath_query(const char_t* query, xpath_variable_set* variables = PUGIXML_NULL); + + // Constructor + xpath_query(); + + // Destructor + ~xpath_query(); + + #ifdef PUGIXML_HAS_MOVE + // Move semantics support + xpath_query(xpath_query&& rhs) PUGIXML_NOEXCEPT; + xpath_query& operator=(xpath_query&& rhs) PUGIXML_NOEXCEPT; + #endif + + // Get query expression return type + xpath_value_type return_type() const; + + // Evaluate expression as boolean value in the specified context; performs type conversion if necessary. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + bool evaluate_boolean(const xpath_node& n) const; + + // Evaluate expression as double value in the specified context; performs type conversion if necessary. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + double evaluate_number(const xpath_node& n) const; + + #ifndef PUGIXML_NO_STL + // Evaluate expression as string value in the specified context; performs type conversion if necessary. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + string_t evaluate_string(const xpath_node& n) const; + #endif + + // Evaluate expression as string value in the specified context; performs type conversion if necessary. + // At most capacity characters are written to the destination buffer, full result size is returned (includes terminating zero). + // If PUGIXML_NO_EXCEPTIONS is not defined, throws std::bad_alloc on out of memory errors. + // If PUGIXML_NO_EXCEPTIONS is defined, returns empty set instead. + size_t evaluate_string(char_t* buffer, size_t capacity, const xpath_node& n) const; + + // Evaluate expression as node set in the specified context. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws xpath_exception on type mismatch and std::bad_alloc on out of memory errors. + // If PUGIXML_NO_EXCEPTIONS is defined, returns empty node set instead. + xpath_node_set evaluate_node_set(const xpath_node& n) const; + + // Evaluate expression as node set in the specified context. + // Return first node in document order, or empty node if node set is empty. + // If PUGIXML_NO_EXCEPTIONS is not defined, throws xpath_exception on type mismatch and std::bad_alloc on out of memory errors. + // If PUGIXML_NO_EXCEPTIONS is defined, returns empty node instead. + xpath_node evaluate_node(const xpath_node& n) const; + + // Get parsing result (used to get compilation errors in PUGIXML_NO_EXCEPTIONS mode) + const xpath_parse_result& result() const; + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + }; + + #ifndef PUGIXML_NO_EXCEPTIONS + #if defined(_MSC_VER) + // C4275 can be ignored in Visual C++ if you are deriving + // from a type in the Standard C++ Library + #pragma warning(push) + #pragma warning(disable: 4275) + #endif + // XPath exception class + class PUGIXML_CLASS xpath_exception: public std::exception + { + private: + xpath_parse_result _result; + + public: + // Construct exception from parse result + explicit xpath_exception(const xpath_parse_result& result); + + // Get error message + virtual const char* what() const throw() PUGIXML_OVERRIDE; + + // Get parse result + const xpath_parse_result& result() const; + }; + #if defined(_MSC_VER) + #pragma warning(pop) + #endif + #endif + + // XPath node class (either xml_node or xml_attribute) + class PUGIXML_CLASS xpath_node + { + private: + xml_node _node; + xml_attribute _attribute; + + typedef void (*unspecified_bool_type)(xpath_node***); + + public: + // Default constructor; constructs empty XPath node + xpath_node(); + + // Construct XPath node from XML node/attribute + xpath_node(const xml_node& node); + xpath_node(const xml_attribute& attribute, const xml_node& parent); + + // Get node/attribute, if any + xml_node node() const; + xml_attribute attribute() const; + + // Get parent of contained node/attribute + xml_node parent() const; + + // Safe bool conversion operator + operator unspecified_bool_type() const; + + // Borland C++ workaround + bool operator!() const; + + // Comparison operators + bool operator==(const xpath_node& n) const; + bool operator!=(const xpath_node& n) const; + }; + +#ifdef __BORLANDC__ + // Borland C++ workaround + bool PUGIXML_FUNCTION operator&&(const xpath_node& lhs, bool rhs); + bool PUGIXML_FUNCTION operator||(const xpath_node& lhs, bool rhs); +#endif + + // A fixed-size collection of XPath nodes + class PUGIXML_CLASS xpath_node_set + { + public: + // Collection type + enum type_t + { + type_unsorted, // Not ordered + type_sorted, // Sorted by document order (ascending) + type_sorted_reverse // Sorted by document order (descending) + }; + + // Constant iterator type + typedef const xpath_node* const_iterator; + + // We define non-constant iterator to be the same as constant iterator so that various generic algorithms (i.e. boost foreach) work + typedef const xpath_node* iterator; + + // Default constructor. Constructs empty set. + xpath_node_set(); + + // Constructs a set from iterator range; data is not checked for duplicates and is not sorted according to provided type, so be careful + xpath_node_set(const_iterator begin, const_iterator end, type_t type = type_unsorted); + + // Destructor + ~xpath_node_set(); + + // Copy constructor/assignment operator + xpath_node_set(const xpath_node_set& ns); + xpath_node_set& operator=(const xpath_node_set& ns); + + #ifdef PUGIXML_HAS_MOVE + // Move semantics support + xpath_node_set(xpath_node_set&& rhs) PUGIXML_NOEXCEPT; + xpath_node_set& operator=(xpath_node_set&& rhs) PUGIXML_NOEXCEPT; + #endif + + // Get collection type + type_t type() const; + + // Get collection size + size_t size() const; + + // Indexing operator + const xpath_node& operator[](size_t index) const; + + // Collection iterators + const_iterator begin() const; + const_iterator end() const; + + // Sort the collection in ascending/descending order by document order + void sort(bool reverse = false); + + // Get first node in the collection by document order + xpath_node first() const; + + // Check if collection is empty + bool empty() const; + + private: + type_t _type; + + xpath_node _storage[1]; + + xpath_node* _begin; + xpath_node* _end; + + void _assign(const_iterator begin, const_iterator end, type_t type); + void _move(xpath_node_set& rhs) PUGIXML_NOEXCEPT; + }; +#endif + +#ifndef PUGIXML_NO_STL + // Convert wide string to UTF8 + std::basic_string, std::allocator > PUGIXML_FUNCTION as_utf8(const wchar_t* str); + std::basic_string, std::allocator > PUGIXML_FUNCTION as_utf8(const std::basic_string, std::allocator >& str); + + // Convert UTF8 to wide string + std::basic_string, std::allocator > PUGIXML_FUNCTION as_wide(const char* str); + std::basic_string, std::allocator > PUGIXML_FUNCTION as_wide(const std::basic_string, std::allocator >& str); +#endif + + // Memory allocation function interface; returns pointer to allocated memory or NULL on failure + typedef void* (*allocation_function)(size_t size); + + // Memory deallocation function interface + typedef void (*deallocation_function)(void* ptr); + + // Override default memory management functions. All subsequent allocations/deallocations will be performed via supplied functions. + void PUGIXML_FUNCTION set_memory_management_functions(allocation_function allocate, deallocation_function deallocate); + + // Get current memory management functions + allocation_function PUGIXML_FUNCTION get_memory_allocation_function(); + deallocation_function PUGIXML_FUNCTION get_memory_deallocation_function(); +} + +#if !defined(PUGIXML_NO_STL) && (defined(_MSC_VER) || defined(__ICC)) +namespace std +{ + // Workarounds for (non-standard) iterator category detection for older versions (MSVC7/IC8 and earlier) + std::bidirectional_iterator_tag PUGIXML_FUNCTION _Iter_cat(const pugi::xml_node_iterator&); + std::bidirectional_iterator_tag PUGIXML_FUNCTION _Iter_cat(const pugi::xml_attribute_iterator&); + std::bidirectional_iterator_tag PUGIXML_FUNCTION _Iter_cat(const pugi::xml_named_node_iterator&); +} +#endif + +#if !defined(PUGIXML_NO_STL) && defined(__SUNPRO_CC) +namespace std +{ + // Workarounds for (non-standard) iterator category detection + std::bidirectional_iterator_tag PUGIXML_FUNCTION __iterator_category(const pugi::xml_node_iterator&); + std::bidirectional_iterator_tag PUGIXML_FUNCTION __iterator_category(const pugi::xml_attribute_iterator&); + std::bidirectional_iterator_tag PUGIXML_FUNCTION __iterator_category(const pugi::xml_named_node_iterator&); +} +#endif + +#endif + +// Make sure implementation is included in header-only mode +// Use macro expansion in #include to work around QMake (QTBUG-11923) +#if defined(PUGIXML_HEADER_ONLY) && !defined(PUGIXML_SOURCE) +# define PUGIXML_SOURCE "pugixml.cpp" +# include PUGIXML_SOURCE +#endif + +/** + * Copyright (c) 2006-2020 Arseny Kapoulkine + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ diff --git a/src/Player/MediaPlayer.cpp b/src/Player/MediaPlayer.cpp index 8ed4cc2f..1a36147f 100644 --- a/src/Player/MediaPlayer.cpp +++ b/src/Player/MediaPlayer.cpp @@ -10,6 +10,7 @@ #include #include "MediaPlayer.h" +#include "Common/config.h" using namespace std; using namespace toolkit; @@ -36,7 +37,7 @@ static void setOnCreateSocket_l(const std::shared_ptr &delegate, con } void MediaPlayer::play(const string &url) { - _delegate = PlayerBase::createPlayer(_poller, url); + _delegate = PlayerBase::createPlayer(_poller, url, (*this)[Client::kSchema]); assert(_delegate); setOnCreateSocket_l(_delegate, _on_create_socket); _delegate->setOnShutdown(_on_shutdown); diff --git a/src/Player/MediaPlayer.h b/src/Player/MediaPlayer.h index ca187019..2d43b6ed 100644 --- a/src/Player/MediaPlayer.h +++ b/src/Player/MediaPlayer.h @@ -26,6 +26,7 @@ public: void play(const std::string &url) override; toolkit::EventPoller::Ptr getPoller(); void setOnCreateSocket(toolkit::Socket::onCreateSocket cb); + const PlayerBase::Ptr& getDelegate() const { return _delegate; } private: toolkit::EventPoller::Ptr _poller; diff --git a/src/Player/PlayerBase.cpp b/src/Player/PlayerBase.cpp index 1c8cc501..b761e61c 100644 --- a/src/Player/PlayerBase.cpp +++ b/src/Player/PlayerBase.cpp @@ -16,15 +16,17 @@ #include "Http/HlsPlayer.h" #include "Http/TsPlayerImp.h" #ifdef ENABLE_SRT -#include "Srt/SrtPlayerImp.h" +#include "../srt/SrtPlayerImp.h" #endif // ENABLE_SRT - +#ifdef ENABLE_WEBRTC +#include "../webrtc/WebRtcProxyPlayerImp.h" +#endif // ENABLE_WEBRTC using namespace std; using namespace toolkit; namespace mediakit { -PlayerBase::Ptr PlayerBase::createPlayer(const EventPoller::Ptr &in_poller, const string &url_in) { +PlayerBase::Ptr PlayerBase::createPlayer(const EventPoller::Ptr &in_poller, const string &url_in, const std::string &schema) { auto poller = in_poller ? in_poller : EventPollerPool::Instance().getPoller(); std::weak_ptr weak_poller = poller; auto release_func = [weak_poller](PlayerBase *ptr) { @@ -37,7 +39,13 @@ PlayerBase::Ptr PlayerBase::createPlayer(const EventPoller::Ptr &in_poller, cons delete ptr; } }; + string url = url_in; + trim(url); + if (url.empty()) { + throw std::invalid_argument("invalid play url: " + url_in); + } + string prefix = findSubString(url.data(), NULL, "://"); auto pos = url.find('?'); if (pos != string::npos) { @@ -62,13 +70,13 @@ PlayerBase::Ptr PlayerBase::createPlayer(const EventPoller::Ptr &in_poller, cons return PlayerBase::Ptr(new RtmpPlayerImp(poller), release_func); } if ((strcasecmp("http", prefix.data()) == 0 || strcasecmp("https", prefix.data()) == 0)) { - if (end_with(url, ".m3u8") || end_with(url_in, ".m3u8")) { + if (end_with(url, ".m3u8") || end_with(url_in, ".m3u8") || schema == "hls") { return PlayerBase::Ptr(new HlsPlayerImp(poller), release_func); } - if (end_with(url, ".ts") || end_with(url_in, ".ts")) { + if (end_with(url, ".ts") || end_with(url_in, ".ts") || schema == "ts") { return PlayerBase::Ptr(new TsPlayerImp(poller), release_func); } - if (end_with(url, ".flv") || end_with(url_in, ".flv")) { + if (end_with(url, ".flv") || end_with(url_in, ".flv") || schema == "flv") { return PlayerBase::Ptr(new FlvPlayerImp(poller), release_func); } } @@ -78,6 +86,11 @@ PlayerBase::Ptr PlayerBase::createPlayer(const EventPoller::Ptr &in_poller, cons return PlayerBase::Ptr(new SrtPlayerImp(poller), release_func); } #endif//ENABLE_SRT +#ifdef ENABLE_WEBRTC + if ((strcasecmp("webrtc", prefix.data()) == 0 || strcasecmp("webrtcs", prefix.data()) == 0)) { + return PlayerBase::Ptr(new WebRtcProxyPlayerImp(poller), release_func); + } +#endif//ENABLE_WEBRTC throw std::invalid_argument("not supported play schema:" + url_in); } diff --git a/src/Player/PlayerBase.h b/src/Player/PlayerBase.h index dc033948..621aa179 100644 --- a/src/Player/PlayerBase.h +++ b/src/Player/PlayerBase.h @@ -21,15 +21,28 @@ #include "Common/MediaSink.h" #include "Extension/Frame.h" #include "Extension/Track.h" +#include "Common/config.h" +#include "Common/Parser.h" namespace mediakit { +template +void addCustomHeader(Type *c) { + auto &custom_header = (*c)[Client::kCustomHeader]; + if (!custom_header.empty()) { + auto args = mediakit::Parser::parseArgs(custom_header); + for (auto &pr : args) { + c->addHeader(pr.first, pr.second); + } + } +} + class PlayerBase : public TrackSource, public toolkit::mINI { public: using Ptr = std::shared_ptr; using Event = std::function; - static Ptr createPlayer(const toolkit::EventPoller::Ptr &poller, const std::string &strUrl); + static Ptr createPlayer(const toolkit::EventPoller::Ptr &poller, const std::string &strUrl, const std::string &schema = ""); PlayerBase(); @@ -165,6 +178,10 @@ public: * [AUTO-TRANSLATED:8fb31d43] */ virtual void setOnResume(const std::function &cb) = 0; + + virtual size_t getRecvSpeed() { return 0; } + virtual size_t getRecvTotalBytes() { return 0; } + virtual std::shared_ptr getSockInfo() const { return nullptr; } protected: virtual void onResume() = 0; @@ -224,8 +241,11 @@ public: return _delegate ? _delegate->getTracks(ready) : Parent::getTracks(ready); } - std::shared_ptr getSockInfo() const { - return std::dynamic_pointer_cast(_delegate); + std::shared_ptr getSockInfo() const override { + auto ret = std::dynamic_pointer_cast(_delegate); + if (!ret) + ret = _delegate ? _delegate->getSockInfo() : Parent::getSockInfo(); + return ret; } void setMediaSource(const MediaSource::Ptr &src) override { @@ -256,6 +276,14 @@ public: _on_resume = cb; } + size_t getRecvSpeed() override { + return _delegate ? _delegate->getRecvSpeed() : Parent::getRecvSpeed(); + } + + size_t getRecvTotalBytes() override { + return _delegate ? _delegate->getRecvTotalBytes() : Parent::getRecvTotalBytes(); + } + protected: void onShutdown(const toolkit::SockException &ex) override { if (_on_shutdown) { diff --git a/src/Player/PlayerProxy.cpp b/src/Player/PlayerProxy.cpp index 9ccd701f..1e1037bb 100644 --- a/src/Player/PlayerProxy.cpp +++ b/src/Player/PlayerProxy.cpp @@ -32,7 +32,7 @@ PlayerProxy::PlayerProxy( setOnClose(nullptr); setOnConnect(nullptr); setOnDisconnect(nullptr); - + _reconnect_delay_min = reconnect_delay_min > 0 ? reconnect_delay_min : 2; _reconnect_delay_max = reconnect_delay_max > 0 ? reconnect_delay_max : 60; _reconnect_delay_step = reconnect_delay_step > 0 ? reconnect_delay_step : 3; @@ -42,6 +42,16 @@ PlayerProxy::PlayerProxy( (*this)[Client::kWaitTrackReady] = false; } +void PlayerProxy::update(const std::string &url, const toolkit::mINI &args) { + CHECK(getPoller()->isCurrentThread()); + _pull_url = url; + this->mINI::clear(); + (*this)[Client::kWaitTrackReady] = false; + for (auto &pr : args) { + (*this)[pr.first] = pr.second; + } +} + void PlayerProxy::setPlayCallbackOnce(function cb) { _on_play = std::move(cb); } @@ -51,15 +61,14 @@ void PlayerProxy::setOnClose(function cb) { } void PlayerProxy::setOnDisconnect(std::function cb) { - _on_disconnect = cb ? std::move(cb) : [] () {}; + _on_disconnect = cb ? std::move(cb) : []() {}; } -void PlayerProxy::setOnConnect(std::function cb) { - _on_connect = cb ? std::move(cb) : [](const TranslationInfo&) {}; +void PlayerProxy::setOnConnect(std::function cb) { + _on_connect = cb ? std::move(cb) : [](const TranslationInfo &) {}; } -void PlayerProxy::setTranslationInfo() -{ +void PlayerProxy::setTranslationInfo() { _transtalion_info.byte_speed = _media_src ? _media_src->getBytesSpeed() : -1; _transtalion_info.start_time_stamp = _media_src ? _media_src->getCreateStamp() : 0; _transtalion_info.stream_info.clear(); @@ -72,22 +81,21 @@ void PlayerProxy::setTranslationInfo() back.codec_type = track->getTrackType(); back.codec_name = track->getCodecName(); switch (back.codec_type) { - case TrackAudio : { + case TrackAudio: { auto audio_track = dynamic_pointer_cast(track); back.audio_sample_rate = audio_track->getAudioSampleRate(); back.audio_channel = audio_track->getAudioChannel(); back.audio_sample_bit = audio_track->getAudioSampleBit(); break; } - case TrackVideo : { + case TrackVideo: { auto video_track = dynamic_pointer_cast(track); back.video_width = video_track->getVideoWidth(); back.video_height = video_track->getVideoHeight(); back.video_fps = video_track->getVideoFps(); break; } - default: - break; + default: break; } } } @@ -101,16 +109,20 @@ static int getMaxTrackSize(const std::string &url) { return 2; } -void PlayerProxy::play(const string &strUrlTmp) { - _option.max_track = getMaxTrackSize(strUrlTmp); +void PlayerProxy::play(const string &url) { + _pull_url = url; + _option.max_track = getMaxTrackSize(_pull_url); weak_ptr weakSelf = shared_from_this(); std::shared_ptr piFailedCnt(new int(0)); // 连续播放失败次数 - setOnPlayResult([weakSelf, strUrlTmp, piFailedCnt](const SockException &err) { + setOnPlayResult([weakSelf, piFailedCnt](const SockException &err) { auto strongSelf = weakSelf.lock(); if (!strongSelf) { return; } - + if (err) { + NOTICE_EMIT(BroadcastPlayerProxyFailedArgs, Broadcast::kBroadcastPlayerProxyFailed, *strongSelf, err); + strongSelf->_status = std::make_shared(std::string("play failed: ") + err.what()); + } if (strongSelf->_on_play) { strongSelf->_on_play(err); strongSelf->_on_play = nullptr; @@ -118,7 +130,8 @@ void PlayerProxy::play(const string &strUrlTmp) { if (!err) { // 取消定时器,避免hls拉流索引文件因为网络波动失败重连成功后出现循环重试的情况 [AUTO-TRANSLATED:91e5f0c8] - // Cancel the timer to avoid the situation where the hls stream index file fails to reconnect due to network fluctuations and then retries in a loop after successful reconnection + // Cancel the timer to avoid the situation where the hls stream index file fails to reconnect due to network fluctuations and then retries in a loop + // after successful reconnection strongSelf->_timer.reset(); strongSelf->_live_ticker.resetTime(); strongSelf->_live_status = 0; @@ -127,25 +140,34 @@ void PlayerProxy::play(const string &strUrlTmp) { *piFailedCnt = 0; // 连续播放失败次数清0 strongSelf->onPlaySuccess(); strongSelf->setTranslationInfo(); - strongSelf->_on_connect(strongSelf->_transtalion_info); + strongSelf->_on_connect(strongSelf->_transtalion_info); - InfoL << "play " << strUrlTmp << " success"; + InfoL << "play " << strongSelf->_pull_url << " success"; + strongSelf->_status = std::make_shared("playing"); } else if (*piFailedCnt < strongSelf->_retry_count || strongSelf->_retry_count < 0) { // 播放失败,延时重试播放 [AUTO-TRANSLATED:d7537c9c] // Play failed, retry playing with delay strongSelf->_on_disconnect(); - strongSelf->rePlay(strUrlTmp, (*piFailedCnt)++); + strongSelf->rePlay((*piFailedCnt)++); } else { // 达到了最大重试次数,回调关闭 [AUTO-TRANSLATED:610f31f3] // Reached the maximum number of retries, callback to close strongSelf->_on_close(err); } }); - setOnShutdown([weakSelf, strUrlTmp, piFailedCnt](const SockException &err) { + setOnShutdown([weakSelf, piFailedCnt](const SockException &err) { auto strongSelf = weakSelf.lock(); if (!strongSelf) { return; } + if (err) { + NOTICE_EMIT(BroadcastPlayerProxyFailedArgs, Broadcast::kBroadcastPlayerProxyFailed, *strongSelf, err); + } + if (strongSelf->_on_play) { + strongSelf->_on_play(err); + strongSelf->_on_play = nullptr; + } + strongSelf->_status = std::make_shared(std::string("play shutdown: ") + err.what()); // 注销直接拉流代理产生的流:#532 [AUTO-TRANSLATED:c6343a3b] // Unregister the stream generated by the direct stream proxy: #532 @@ -177,7 +199,7 @@ void PlayerProxy::play(const string &strUrlTmp) { // Play interrupted abnormally, retry playing with delay if (*piFailedCnt < strongSelf->_retry_count || strongSelf->_retry_count < 0) { strongSelf->_repull_count++; - strongSelf->rePlay(strUrlTmp, (*piFailedCnt)++); + strongSelf->rePlay((*piFailedCnt)++); } else { // 达到了最大重试次数,回调关闭 [AUTO-TRANSLATED:610f31f3] // Reached the maximum number of retries, callback to close @@ -185,13 +207,14 @@ void PlayerProxy::play(const string &strUrlTmp) { } }); try { - MediaPlayer::play(strUrlTmp); + _status = std::make_shared("connecting"); + MediaPlayer::play(_pull_url ); } catch (std::exception &ex) { + _status = std::make_shared(std::string("play failed: ") + ex.what()); ErrorL << ex.what(); onPlayResult(SockException(Err_other, ex.what())); return; } - _pull_url = strUrlTmp; setDirectProxy(); } @@ -231,39 +254,29 @@ PlayerProxy::~PlayerProxy() { } } -void PlayerProxy::rePlay(const string &strUrl, int iFailedCnt) { +void PlayerProxy::rePlay(int iFailedCnt) { auto iDelay = MAX(_reconnect_delay_min * 1000, MIN(iFailedCnt * _reconnect_delay_step * 1000, _reconnect_delay_max * 1000)); weak_ptr weakSelf = shared_from_this(); - _timer = std::make_shared( - iDelay / 1000.0f, - [weakSelf, strUrl, iFailedCnt]() { - // 播放失败次数越多,则延时越长 [AUTO-TRANSLATED:5af39264] - // The more times the playback fails, the longer the delay - auto strongPlayer = weakSelf.lock(); - if (!strongPlayer) { - return false; - } - WarnL << "重试播放[" << iFailedCnt << "]:" << strUrl; - strongPlayer->MediaPlayer::play(strUrl); - strongPlayer->setDirectProxy(); + _timer = std::make_shared(iDelay / 1000.0f, [weakSelf, iFailedCnt]() { + // 播放失败次数越多,则延时越长 [AUTO-TRANSLATED:5af39264] + // The more times the playback fails, the longer the delay + auto strongPlayer = weakSelf.lock(); + if (!strongPlayer) { return false; - }, - getPoller()); + } + WarnL << "重试播放[" << iFailedCnt << "]:" << strongPlayer->_pull_url; + strongPlayer->MediaPlayer::play(strongPlayer->_pull_url); + strongPlayer->setDirectProxy(); + return false; + }, getPoller()); } bool PlayerProxy::close(MediaSource &sender) { // 通知其停止推流 [AUTO-TRANSLATED:d69d10d8] // Notify it to stop pushing the stream - weak_ptr weakSelf = dynamic_pointer_cast(shared_from_this()); - getPoller()->async_first([weakSelf]() { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) { - return; - } - strongSelf->_muxer.reset(); - strongSelf->setMediaSource(nullptr); - strongSelf->teardown(); - }); + _muxer = nullptr; + setMediaSource(nullptr); + teardown(); _on_close(SockException(Err_shutdown, "closed by user")); WarnL << "close media: " << sender.getUrl(); return true; @@ -293,6 +306,10 @@ float PlayerProxy::getLossRate(MediaSource &sender, TrackType type) { return getPacketLossRate(type); } +toolkit::EventPoller::Ptr PlayerProxy::getOwnerPoller(MediaSource &sender) { + return getPoller(); +} + TranslationInfo PlayerProxy::getTranslationInfo() { return _transtalion_info; } @@ -312,10 +329,10 @@ void PlayerProxy::onPlaySuccess() { // rtmp拉流代理 [AUTO-TRANSLATED:21173335] // Rtmp stream proxy if (reset_when_replay || !_muxer) { - auto old = _option.enable_rtmp; + auto old = _option.enable_rtmp; _option.enable_rtmp = false; _muxer = std::make_shared(_tuple, getDuration(), _option); - _option.enable_rtmp = old; + _option.enable_rtmp = old; } } else { // 其他拉流代理 [AUTO-TRANSLATED:e5f2e45d] @@ -360,6 +377,12 @@ void PlayerProxy::onPlaySuccess() { int PlayerProxy::getStatus() { return _live_status.load(); } + +std::string PlayerProxy::getStatusStr() const { + auto status = _status; + return status ? *status : "unknown"; +} + uint64_t PlayerProxy::getLiveSecs() { if (_live_status == 0) { return _live_secs + _live_ticker.elapsedTime() / 1000; diff --git a/src/Player/PlayerProxy.h b/src/Player/PlayerProxy.h index 304d22a6..853e6632 100644 --- a/src/Player/PlayerProxy.h +++ b/src/Player/PlayerProxy.h @@ -18,8 +18,7 @@ namespace mediakit { -struct StreamInfo -{ +struct StreamInfo { TrackType codec_type; std::string codec_name; int bitrate; @@ -30,8 +29,7 @@ struct StreamInfo int video_height; float video_fps; - StreamInfo() - { + StreamInfo() { codec_type = TrackInvalid; codec_name = "none"; bitrate = -1; @@ -44,14 +42,12 @@ struct StreamInfo } }; -struct TranslationInfo -{ +struct TranslationInfo { std::vector stream_info; int byte_speed; uint64_t start_time_stamp; - TranslationInfo() - { + TranslationInfo() { byte_speed = -1; start_time_stamp = 0; } @@ -133,6 +129,7 @@ public: int totalReaderCount(); int getStatus(); + std::string getStatusStr() const; uint64_t getLiveSecs(); uint64_t getRePullCount(); @@ -143,6 +140,8 @@ public: const MediaTuple& getMediaTuple() const { return _tuple; } const ProtocolOption& getOption() const { return _option; } + void update(const std::string &url, const toolkit::mINI &args); + private: // MediaSourceEvent override bool close(MediaSource &sender) override; @@ -151,13 +150,15 @@ private: std::string getOriginUrl(MediaSource &sender) const override; std::shared_ptr getOriginSock(MediaSource &sender) const override; float getLossRate(MediaSource &sender, TrackType type) override; + toolkit::EventPoller::Ptr getOwnerPoller(MediaSource &sender) override; - void rePlay(const std::string &strUrl, int iFailedCnt); + void rePlay(int iFailedCnt); void onPlaySuccess(); void setDirectProxy(); void setTranslationInfo(); private: + std::shared_ptr _status; int _retry_count; int _reconnect_delay_min; int _reconnect_delay_max; diff --git a/src/Pusher/MediaPusher.h b/src/Pusher/MediaPusher.h index cd9c2c9f..2b53e1d7 100644 --- a/src/Pusher/MediaPusher.h +++ b/src/Pusher/MediaPusher.h @@ -35,6 +35,7 @@ public: void setOnCreateSocket(toolkit::Socket::onCreateSocket cb); std::shared_ptr getSrc() { return _src.lock(); } const std::string& getUrl() const { return _url; } + private: std::weak_ptr _src; toolkit::EventPoller::Ptr _poller; diff --git a/src/Pusher/PusherBase.cpp b/src/Pusher/PusherBase.cpp index 63b2011d..496a726f 100644 --- a/src/Pusher/PusherBase.cpp +++ b/src/Pusher/PusherBase.cpp @@ -13,13 +13,45 @@ #include "Rtsp/RtspPusher.h" #include "Rtmp/RtmpPusher.h" #ifdef ENABLE_SRT -#include "Srt/SrtPusher.h" +#include "../srt/SrtPusher.h" #endif // ENABLE_SRT +#ifdef ENABLE_WEBRTC +#include "../webrtc/WebRtcProxyPusher.h" +#endif // ENABLE_WEBRTC using namespace toolkit; namespace mediakit { +static bool checkMediaSourceAndUrlMatch(const MediaSource::Ptr &src, const std::string &url) { + std::string prefix = findSubString(url.data(), NULL, "://"); + + if (strcasecmp("rtsps", prefix.data()) == 0 || strcasecmp("rtsp", prefix.data()) == 0 || + strcasecmp("webrtcs", prefix.data()) == 0 || strcasecmp("webrtc", prefix.data()) == 0 ) { + auto rtsp_src = std::dynamic_pointer_cast(src); + if (!rtsp_src) { + return false; + } + } + + if (strcasecmp("rtmp", prefix.data()) == 0 || strcasecmp("rtmps", prefix.data()) == 0) { + auto rtmp_src = std::dynamic_pointer_cast(src); + if (!rtmp_src) { + return false; + } + } + +#ifdef ENABLE_SRT + if (strcasecmp("srt", prefix.data()) == 0) { + auto ts_src = std::dynamic_pointer_cast(src); + if (!ts_src) { + return false; + } + } +#endif // ENABLE_SRT + return true; +} + PusherBase::Ptr PusherBase::createPusher(const EventPoller::Ptr &in_poller, const MediaSource::Ptr &src, const std::string & url) { @@ -35,6 +67,10 @@ PusherBase::Ptr PusherBase::createPusher(const EventPoller::Ptr &in_poller, delete ptr; } }; + if (!checkMediaSourceAndUrlMatch(src, url)) { + throw std::invalid_argument(" media source (schema) and push url not match"); + } + std::string prefix = findSubString(url.data(), NULL, "://"); if (strcasecmp("rtsps",prefix.data()) == 0) { @@ -59,6 +95,11 @@ PusherBase::Ptr PusherBase::createPusher(const EventPoller::Ptr &in_poller, } #endif//ENABLE_SRT +#ifdef ENABLE_WEBRTC + if ((strcasecmp("webrtc", prefix.data()) == 0 || strcasecmp("webrtcs", prefix.data()) == 0)) { + return PusherBase::Ptr(new WebRtcProxyPusherImp(poller, std::dynamic_pointer_cast(src)), release_func); + } +#endif//ENABLE_WEBRTC throw std::invalid_argument("not supported push schema:" + url); } diff --git a/src/Pusher/PusherBase.h b/src/Pusher/PusherBase.h index 859722ed..269a5f38 100644 --- a/src/Pusher/PusherBase.h +++ b/src/Pusher/PusherBase.h @@ -67,6 +67,9 @@ public: */ virtual void setOnShutdown(const Event &cb) = 0; + virtual size_t getSendSpeed() { return 0; } + virtual size_t getSendTotalBytes() { return 0; } + protected: virtual void onShutdown(const toolkit::SockException &ex) = 0; virtual void onPublishResult(const toolkit::SockException &ex) = 0; @@ -133,6 +136,14 @@ public: _on_shutdown = cb; } + size_t getSendSpeed() override { + return _delegate ? _delegate->getSendSpeed() : Parent::getSendSpeed(); + } + + size_t getSendTotalBytes() override { + return _delegate ? _delegate->getSendTotalBytes() : Parent::getSendTotalBytes(); + } + protected: void onShutdown(const toolkit::SockException &ex) override { if (_on_shutdown) { diff --git a/src/Pusher/PusherProxy.cpp b/src/Pusher/PusherProxy.cpp index 742df665..fef8f9f8 100644 --- a/src/Pusher/PusherProxy.cpp +++ b/src/Pusher/PusherProxy.cpp @@ -115,7 +115,13 @@ void PusherProxy::rePublish(const string &dst_url, int failed_cnt) { return false; } WarnL << "推流重试[" << failed_cnt << "]:" << dst_url; - strong_self->MediaPusher::publish(dst_url); + try { + strong_self->MediaPusher::publish(dst_url); + } catch (std::exception &e) { + WarnL << e.what(); + // 回调推流失败,一般是媒体注销了 + strong_self->_on_close(SockException(Err_other, e.what())); + } return false; }, getPoller()); diff --git a/src/Record/HlsMaker.cpp b/src/Record/HlsMaker.cpp index 9905b4be..c5df3a9e 100644 --- a/src/Record/HlsMaker.cpp +++ b/src/Record/HlsMaker.cpp @@ -17,7 +17,7 @@ using namespace std; namespace mediakit { HlsMaker::HlsMaker(bool is_fmp4, float seg_duration, uint32_t seg_number, bool seg_keep) { - _is_fmp4 = is_fmp4; + _is_fmp4 = is_fmp4; // 最小允许设置为0,0个切片代表点播 [AUTO-TRANSLATED:19235e8e] // Minimum allowed setting is 0, 0 slices represent on-demand _seg_number = seg_number; diff --git a/src/Record/HlsMediaSource.cpp b/src/Record/HlsMediaSource.cpp index e1d9b127..754b994e 100644 --- a/src/Record/HlsMediaSource.cpp +++ b/src/Record/HlsMediaSource.cpp @@ -15,9 +15,37 @@ using namespace toolkit; namespace mediakit { -HlsCookieData::HlsCookieData(const MediaInfo &info, const std::shared_ptr &sock_info) { +class SockInfoImp : public SockInfo { +public: + using Ptr = std::shared_ptr; + + std::string get_local_ip() override { return _local_ip; } + + uint16_t get_local_port() override { return _local_port; } + + std::string get_peer_ip() override { return _peer_ip; } + + uint16_t get_peer_port() override { return _peer_port; } + + std::string getIdentifier() const override { return _identifier; } + + std::string _local_ip; + std::string _peer_ip; + std::string _identifier; + uint16_t _local_port; + uint16_t _peer_port; +}; + +HlsCookieData::HlsCookieData(const MediaInfo &info, const std::shared_ptr &session) { _info = info; + auto sock_info = std::make_shared(); + sock_info->_identifier = session->getIdentifier(); + sock_info->_peer_ip = session->get_peer_ip(); + sock_info->_peer_port = session->get_peer_port(); + sock_info->_local_ip = session->get_local_ip(); + sock_info->_local_port = session->get_local_port(); _sock_info = sock_info; + _session = session; _added = std::make_shared(false); addReaderCount(); } @@ -34,10 +62,10 @@ void HlsCookieData::addReaderCount() { // HlsMediaSource has been destroyed *added = false; }); - auto info = _sock_info; - _ring_reader->setGetInfoCB([info]() { + std::weak_ptr weak_session = _session; + _ring_reader->setGetInfoCB([weak_session]() { Any ret; - ret.set(info); + ret.set(std::static_pointer_cast(weak_session.lock())); return ret; }); } diff --git a/src/Record/HlsMediaSource.h b/src/Record/HlsMediaSource.h index 53f74520..64ace6fb 100644 --- a/src/Record/HlsMediaSource.h +++ b/src/Record/HlsMediaSource.h @@ -14,6 +14,7 @@ #include "Common/MediaSource.h" #include "Util/TimeTicker.h" #include "Util/RingBuffer.h" +#include "Network/Session.h" #include namespace mediakit { @@ -89,7 +90,7 @@ class HlsCookieData { public: using Ptr = std::shared_ptr; - HlsCookieData(const MediaInfo &info, const std::shared_ptr &sock_info); + HlsCookieData(const MediaInfo &info, const std::shared_ptr &session); ~HlsCookieData(); void addByteUsage(size_t bytes); @@ -106,6 +107,7 @@ private: toolkit::Ticker _ticker; std::weak_ptr _src; std::shared_ptr _sock_info; + std::weak_ptr _session; HlsMediaSource::RingType::RingReader::Ptr _ring_reader; }; diff --git a/src/Record/MP4Demuxer.cpp b/src/Record/MP4Demuxer.cpp index 187224fc..27bdce51 100644 --- a/src/Record/MP4Demuxer.cpp +++ b/src/Record/MP4Demuxer.cpp @@ -138,7 +138,7 @@ Frame::Ptr MP4Demuxer::readFrame(bool &keyFrame, bool &eof) { } } -Frame::Ptr MP4Demuxer::makeFrame(uint32_t track_id, const Buffer::Ptr &buf, int64_t pts, int64_t dts) { +Frame::Ptr MP4Demuxer::makeFrame(uint32_t track_id, Buffer::Ptr buf, int64_t pts, int64_t dts) { auto it = _tracks.find(track_id); if (it == _tracks.end()) { return nullptr; @@ -198,11 +198,11 @@ void MultiMP4Demuxer::openMP4(const string &files_string) { std::vector files; if (File::is_dir(files_string)) { File::scanDir(files_string, [&](const string &path, bool is_dir) { - if (!is_dir) { + if (!is_dir && end_with(path, ".mp4")) { files.emplace_back(path); } return true; - }); + }, true); std::sort(files.begin(), files.end()); } else { files = split(files_string, ";"); @@ -218,7 +218,10 @@ void MultiMP4Demuxer::openMP4(const string &files_string) { CHECK(!_demuxers.empty()); _it = _demuxers.begin(); for (auto &track : _it->second->getTracks(false)) { - _tracks.emplace(track->getIndex(), track->clone()); + auto clone_track(track->clone()); + clone_track->setIndex(clone_track->getTrackType()); + _tracks.emplace(clone_track->getIndex(), clone_track); + DebugL << "track index: " << track->getIndex() << " -> " << clone_track->getIndex(); } } @@ -244,6 +247,7 @@ Frame::Ptr MultiMP4Demuxer::readFrame(bool &keyFrame, bool &eof) { for (;;) { auto ret = _it->second->readFrame(keyFrame, eof); if (ret) { + ret->setIndex(ret->getTrackType()); auto it = _tracks.find(ret->getIndex()); if (it != _tracks.end()) { auto ret2 = std::make_shared(ret); diff --git a/src/Record/MP4Demuxer.h b/src/Record/MP4Demuxer.h index 4d5d168e..8299dbfc 100644 --- a/src/Record/MP4Demuxer.h +++ b/src/Record/MP4Demuxer.h @@ -96,7 +96,7 @@ private: int getAllTracks(); void onVideoTrack(uint32_t track_id, uint8_t object, int width, int height, const void *extra, size_t bytes); void onAudioTrack(uint32_t track_id, uint8_t object, int channel_count, int bit_per_sample, int sample_rate, const void *extra, size_t bytes); - Frame::Ptr makeFrame(uint32_t track_id, const toolkit::Buffer::Ptr &buf, int64_t pts, int64_t dts); + Frame::Ptr makeFrame(uint32_t track_id, toolkit::Buffer::Ptr buf, int64_t pts, int64_t dts); private: MP4FileDisk::Ptr _mp4_file; diff --git a/src/Record/MP4Muxer.cpp b/src/Record/MP4Muxer.cpp index 8fd0567e..35dec7fd 100644 --- a/src/Record/MP4Muxer.cpp +++ b/src/Record/MP4Muxer.cpp @@ -19,7 +19,11 @@ using namespace toolkit; namespace mediakit { MP4Muxer::~MP4Muxer() { - closeMP4(); + try { + closeMP4(); + } catch (std::exception &e) { + WarnL << e.what(); + } } void MP4Muxer::openMP4(const string &file) { diff --git a/src/Record/MP4Recorder.cpp b/src/Record/MP4Recorder.cpp index 7576ba6a..0b438dd2 100644 --- a/src/Record/MP4Recorder.cpp +++ b/src/Record/MP4Recorder.cpp @@ -23,7 +23,6 @@ using namespace toolkit; namespace mediakit { MP4Recorder::MP4Recorder(const MediaTuple &tuple, const string &path, size_t max_second) { - _folder_path = path; // ///record 业务逻辑////// [AUTO-TRANSLATED:2e78931a] // ///record Business Logic////// static_cast(_info) = tuple; @@ -44,9 +43,9 @@ MP4Recorder::~MP4Recorder() { void MP4Recorder::createFile() { closeFile(); auto date = getTimeStr("%Y-%m-%d"); - auto file_name = getTimeStr("%H-%M-%S") + "-" + std::to_string(_file_index++) + ".mp4"; - auto full_path = _folder_path + date + "/" + file_name; - auto full_path_tmp = _folder_path + date + "/." + file_name; + auto file_name = date + "-" + getTimeStr("%H-%M-%S") + "-" + std::to_string(_file_index++) + ".mp4"; + auto full_path = _info.folder + date + "/" + file_name; + auto full_path_tmp = _info.folder + date + "/." + file_name; // ///record 业务逻辑////// [AUTO-TRANSLATED:2e78931a] // ///record Business Logic////// @@ -66,7 +65,6 @@ void MP4Recorder::createFile() { _muxer->addTrack(track); } _full_path_tmp = full_path_tmp; - _full_path = full_path; } catch (std::exception &ex) { WarnL << ex.what(); } @@ -75,10 +73,9 @@ void MP4Recorder::createFile() { void MP4Recorder::asyncClose() { auto muxer = _muxer; auto full_path_tmp = _full_path_tmp; - auto full_path = _full_path; auto info = _info; TraceL << "Start close tmp mp4 file: " << full_path_tmp; - WorkThreadPool::Instance().getExecutor()->async([muxer, full_path_tmp, full_path, info]() mutable { + WorkThreadPool::Instance().getExecutor()->async([muxer, full_path_tmp, info]() mutable { info.time_len = muxer->getDuration() / 1000.0f; // 关闭mp4可能非常耗时,所以要放在后台线程执行 [AUTO-TRANSLATED:a7378a11] // Closing mp4 can be very time-consuming, so it should be executed in the background thread @@ -97,9 +94,9 @@ void MP4Recorder::asyncClose() { } // 临时文件名改成正式文件名,防止mp4未完成时被访问 [AUTO-TRANSLATED:541a6f00] // Change the temporary file name to the official file name to prevent access to the mp4 before it is completed - rename(full_path_tmp.data(), full_path.data()); + rename(full_path_tmp.data(), info.file_path.data()); } - TraceL << "Emit mp4 record event: " << full_path; + TraceL << "Emit mp4 record event: " << info.file_path; // 触发mp4录制切片生成事件 [AUTO-TRANSLATED:9959dcd4] // Trigger mp4 recording slice generation event NOTICE_EMIT(BroadcastRecordMP4Args, Broadcast::kBroadcastRecordMP4, info); @@ -120,33 +117,19 @@ void MP4Recorder::flush() { } bool MP4Recorder::inputFrame(const Frame::Ptr &frame) { - if (!(_have_video && frame->getTrackType() == TrackAudio)) { - // 如果有视频且输入的是音频,那么应该忽略切片逻辑 [AUTO-TRANSLATED:fbb15d93] - // If there is video and the input is audio, then the slice logic should be ignored - if (_last_dts == 0) { - // first frame assign dts - _last_dts = frame->dts(); - } else if (_last_dts > frame->dts()) { - // b帧情况下dts时间戳可能回退 [AUTO-TRANSLATED:1de38f77] - // In the case of b-frames, the dts timestamp may regress - _last_dts = MIN(frame->dts(), _last_dts); - } - - auto duration = 5u; // 默认至少一帧5ms - if (frame->dts() > 0 && frame->dts() > _last_dts) { - duration = MAX(duration, frame->dts() - _last_dts); - } - if (!_muxer || ((duration > _max_second * 1000) && (!_have_video || (_have_video && frame->keyFrame())))) { - // 成立条件 [AUTO-TRANSLATED:8c9c6083] - // Conditions for establishment - // 1、_muxer为空 [AUTO-TRANSLATED:fa236097] - // 1. _muxer is empty - // 2、到了切片时间,并且只有音频 [AUTO-TRANSLATED:212e9d23] - // 2. It's time to slice, and there is only audio - // 3、到了切片时间,有视频并且遇到视频的关键帧 [AUTO-TRANSLATED:fa4a71ad] - // 3. It's time to slice, there is video and a video keyframe is encountered - _last_dts = 0; - createFile(); + auto stamp_inc = _delta_stamp[frame->getTrackType()].relativeStamp(frame->pts(), false); + if (!_muxer || (stamp_inc > int64_t(_max_second) * 1000 && (!_have_video || frame->keyFrame()))) { + // 成立条件 [AUTO-TRANSLATED:8c9c6083] + // Conditions for establishment + // 1、_muxer为空 [AUTO-TRANSLATED:fa236097] + // 1. _muxer is empty + // 2、到了切片时间,并且只有音频 [AUTO-TRANSLATED:212e9d23] + // 2. It's time to slice, and there is only audio + // 3、到了切片时间,有视频并且遇到视频的关键帧 [AUTO-TRANSLATED:fa4a71ad] + // 3. It's time to slice, there is video and a video keyframe is encountered + createFile(); + for (auto &ref : _delta_stamp) { + ref.reset(); } } diff --git a/src/Record/MP4Recorder.h b/src/Record/MP4Recorder.h index 6c042b8e..c12585d4 100644 --- a/src/Record/MP4Recorder.h +++ b/src/Record/MP4Recorder.h @@ -70,10 +70,8 @@ private: private: bool _have_video = false; size_t _max_second; - uint64_t _last_dts = 0; - uint64_t _file_index = 0; - std::string _folder_path; - std::string _full_path; + DeltaStamp _delta_stamp[TrackMax]; + std::atomic _file_index { 0 }; std::string _full_path_tmp; RecordInfo _info; MP4Muxer::Ptr _muxer; diff --git a/src/Record/MPEG.cpp b/src/Record/MPEG.cpp index 483cf12c..65a9dfd7 100644 --- a/src/Record/MPEG.cpp +++ b/src/Record/MPEG.cpp @@ -68,7 +68,7 @@ bool MpegMuxer::inputFrame(const Frame::Ptr &frame) { } case CodecAAC: { - CHECK(frame->prefixSize(), "Mpeg muxer required aac frame with adts heade"); + CHECK(frame->prefixSize(), "Mpeg muxer required aac frame with adts header"); } default: { diff --git a/src/Record/Recorder.cpp b/src/Record/Recorder.cpp index 51b9e031..3247b448 100644 --- a/src/Record/Recorder.cpp +++ b/src/Record/Recorder.cpp @@ -25,13 +25,15 @@ namespace mediakit { string Recorder::getRecordPath(Recorder::type type, const MediaTuple& tuple, const string &customized_path) { GET_CONFIG(bool, enableVhost, General::kEnableVhost); switch (type) { + case Recorder::type_hls_fmp4: case Recorder::type_hls: { GET_CONFIG(string, hlsPath, Protocol::kHlsSavePath); string m3u8FilePath; + auto tail = type == Recorder::type_hls ? "/hls.m3u8" : "/hls.fmp4.m3u8"; if (enableVhost) { - m3u8FilePath = tuple.shortUrl() + "/hls.m3u8"; + m3u8FilePath = tuple.shortUrl() + tail; } else { - m3u8FilePath = tuple.app + "/" + tuple.stream + "/hls.m3u8"; + m3u8FilePath = tuple.app + "/" + tuple.stream + tail; } //Here we use the customized file path. if (!customized_path.empty()) { @@ -54,20 +56,6 @@ string Recorder::getRecordPath(Recorder::type type, const MediaTuple& tuple, con } return File::absolutePath(mp4FilePath, recordPath); } - case Recorder::type_hls_fmp4: { - GET_CONFIG(string, hlsPath, Protocol::kHlsSavePath); - string m3u8FilePath; - if (enableVhost) { - m3u8FilePath = tuple.shortUrl() + "/hls.fmp4.m3u8"; - } else { - m3u8FilePath = tuple.app + "/" + tuple.stream + "/hls.fmp4.m3u8"; - } - // Here we use the customized file path. - if (!customized_path.empty()) { - return File::absolutePath(m3u8FilePath, customized_path); - } - return File::absolutePath(m3u8FilePath, hlsPath); - } default: return ""; } } diff --git a/src/Record/Recorder.h b/src/Record/Recorder.h index 1e73453d..e2107daf 100644 --- a/src/Record/Recorder.h +++ b/src/Record/Recorder.h @@ -13,6 +13,7 @@ #include #include +#include namespace mediakit { class MediaSinkInterface; @@ -26,6 +27,11 @@ struct MediaTuple { std::string shortUrl() const { return vhost + '/' + app + '/' + stream; } + + MediaTuple() = default; + MediaTuple(std::string vhost, std::string app, std::string stream, std::string params = "") + : vhost(std::move(vhost)), app(std::move(app)), stream(std::move(stream)), params(std::move(params)) { + } }; class RecordInfo: public MediaTuple { diff --git a/src/Rtcp/Rtcp.cpp b/src/Rtcp/Rtcp.cpp index d9c07b29..7d44bf77 100644 --- a/src/Rtcp/Rtcp.cpp +++ b/src/Rtcp/Rtcp.cpp @@ -114,7 +114,7 @@ string RtcpHeader::dumpHeader() const { printer << "pt:" << rtcpTypeToStr((RtcpType)pt) << "\r\n"; printer << "size:" << getSize() << "\r\n"; printer << "--------\r\n"; - return std::move(printer); + return printer; } string RtcpHeader::dumpString() const { @@ -322,7 +322,7 @@ string RtcpSR::dumpString() const { printer << "---- item:" << i++ << " ----\r\n"; printer << item->dumpString(); } - return std::move(printer); + return printer; } #define CHECK_MIN_SIZE(size, kMinSize) \ @@ -385,7 +385,7 @@ string ReportItem::dumpString() const { printer << "jitter:" << jitter << "\r\n"; printer << "last_sr_stamp:" << last_sr_stamp << "\r\n"; printer << "delay_since_last_sr:" << delay_since_last_sr << "\r\n"; - return std::move(printer); + return printer; } void ReportItem::net2Host() { @@ -419,7 +419,7 @@ string RtcpRR::dumpString() const { printer << "---- item:" << i++ << " ----\r\n"; printer << item->dumpString(); } - return std::move(printer); + return printer; } void RtcpRR::net2Host(size_t size) { @@ -467,7 +467,7 @@ string SdesChunk::dumpString() const { printer << "type:" << sdesTypeToStr((SdesType)type) << "\r\n"; printer << "txt_len:" << (int)txt_len << "\r\n"; printer << "text:" << (txt_len ? string(text, txt_len) : "") << "\r\n"; - return std::move(printer); + return printer; } ///////////////////////////////////////////////////////////////////////////// @@ -506,7 +506,7 @@ string RtcpSdes::dumpString() const { printer << "---- item:" << i++ << " ----\r\n"; printer << item->dumpString(); } - return std::move(printer); + return printer; } void RtcpSdes::net2Host(size_t size) { @@ -627,7 +627,7 @@ string RtcpFB::dumpString() const { } default: /*不可达*/ assert(0); break; } - return std::move(printer); + return printer; } void RtcpFB::net2Host(size_t size) { @@ -684,7 +684,7 @@ string RtcpBye::dumpString() const { printer << "ssrc:" << *ssrc << "\r\n"; } printer << "reason:" << getReason(); - return std::move(printer); + return printer; } void RtcpBye::net2Host(size_t size) { @@ -719,7 +719,7 @@ string RtcpXRRRTR::dumpString() const { printer << "block_length : " << block_length << "\r\n"; printer << "ntp msw : " << ntpmsw << "\r\n"; printer << "ntp lsw : " << ntplsw << "\r\n"; - return std::move(printer); + return printer; } void RtcpXRRRTR::net2Host(size_t size) { @@ -743,7 +743,7 @@ string RtcpXRDLRRReportItem::dumpString() const { printer << "last RR (lrr) :" << lrr << "\r\n"; printer << "delay since last RR (dlrr): " << dlrr << "\r\n"; - return std::move(printer); + return printer; } void RtcpXRDLRRReportItem::net2Host() { @@ -774,7 +774,7 @@ string RtcpXRDLRR::dumpString() const { printer << "---- item:" << i++ << " ----\r\n"; printer << item->dumpString(); } - return std::move(printer); + return printer; } void RtcpXRDLRR::net2Host(size_t size) { @@ -809,7 +809,7 @@ string RtcpXRTargetBitrateItem::dumpString() const { printer << "Temporal Layer :" << temporal_layer << "\r\n"; printer << "Target Bitrate: " << target_bitrate << "\r\n"; - return std::move(printer); + return printer; } void RtcpXRTargetBitrateItem::net2Host() { @@ -839,7 +839,7 @@ string RtcpXRTargetBitrate::dumpString() const { printer << "---- item:" << i++ << " ----\r\n"; printer << item->dumpString(); } - return std::move(printer); + return printer; } void RtcpXRTargetBitrate::net2Host(size_t size) { diff --git a/src/Rtcp/RtcpFCI.cpp b/src/Rtcp/RtcpFCI.cpp index 64e7b390..e15732ab 100644 --- a/src/Rtcp/RtcpFCI.cpp +++ b/src/Rtcp/RtcpFCI.cpp @@ -153,7 +153,7 @@ string FCI_REMB::dumpString() const { for (auto &ssrc : ((FCI_REMB *)this)->getSSRC()) { printer << ssrc << " "; } - return std::move(printer); + return printer; } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -208,7 +208,7 @@ string FCI_NACK::dumpString() const { } ++pid; } - return std::move(printer); + return printer; } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -267,7 +267,7 @@ uint16_t RunLengthChunk::getRunLength() const { string RunLengthChunk::dumpString() const { _StrPrinter printer; printer << "run length chunk, symbol:" << (int)symbol << ", run length:" << getRunLength(); - return std::move(printer); + return printer; } /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -361,7 +361,7 @@ string StatusVecChunk::dumpString() const { for (auto &item : vec) { printer << (int)item << " "; } - return std::move(printer); + return printer; } /////////////////////////////////////////////////////// @@ -496,7 +496,7 @@ string FCI_TWCC::dumpString(size_t total_size) const { printer << "rtp seq:" << pr.first << ", packet status:" << (int)(pr.second.first) << ", delta:" << pr.second.second << "\n"; } - return std::move(printer); + return printer; } static void appendDeltaString(string &delta_str, FCI_TWCC::TwccPacketStatus &status, int count) { diff --git a/src/Rtmp/FlvMuxer.cpp b/src/Rtmp/FlvMuxer.cpp index 936a9b9a..1ed38f27 100644 --- a/src/Rtmp/FlvMuxer.cpp +++ b/src/Rtmp/FlvMuxer.cpp @@ -13,7 +13,6 @@ #include "Rtmp/utils.h" #include "Http/HttpSession.h" -#define FILE_BUF_SIZE (64 * 1024) using namespace std; using namespace toolkit; @@ -49,7 +48,7 @@ void FlvMuxer::start(const EventPoller::Ptr &poller, const RtmpMediaSource::Ptr _ring_reader = media->getRing()->attach(poller); _ring_reader->setGetInfoCB([weak_self]() { Any ret; - ret.set(dynamic_pointer_cast(weak_self.lock())); + ret.set(dynamic_pointer_cast(weak_self.lock())); return ret; }); _ring_reader->setDetachCB([weak_self]() { @@ -168,11 +167,12 @@ void FlvRecorder::startRecord(const EventPoller::Ptr &poller, const string &vhos void FlvRecorder::startRecord(const EventPoller::Ptr &poller, const RtmpMediaSource::Ptr &media, const string &file_path) { + GET_CONFIG(uint32_t, flvBufSize, Record::kFileBufSize); stop(); lock_guard lck(_file_mtx); // 开辟文件写缓存 [AUTO-TRANSLATED:22d1c17f] // Allocate file write cache. - std::shared_ptr fileBuf(new char[FILE_BUF_SIZE], [](char *ptr) { + std::shared_ptr fileBuf(new char[flvBufSize], [](char *ptr) { if (ptr) { delete[] ptr; } @@ -191,7 +191,7 @@ void FlvRecorder::startRecord(const EventPoller::Ptr &poller, const RtmpMediaSou // 设置文件写缓存 [AUTO-TRANSLATED:a767e55c] // Set the file write cache. - setvbuf(_file.get(), fileBuf.get(), _IOFBF, FILE_BUF_SIZE); + setvbuf(_file.get(), fileBuf.get(), _IOFBF, flvBufSize); start(poller, media); } diff --git a/src/Rtmp/FlvPlayer.cpp b/src/Rtmp/FlvPlayer.cpp index 8349ab8f..f890ca9c 100644 --- a/src/Rtmp/FlvPlayer.cpp +++ b/src/Rtmp/FlvPlayer.cpp @@ -26,6 +26,7 @@ void FlvPlayer::play(const string &url) { setHeaderTimeout((*this)[Client::kTimeoutMS].as()); setBodyTimeout((*this)[Client::kMediaTimeoutMS].as()); setMethod("GET"); + addCustomHeader(this); sendRequest(url); } @@ -76,4 +77,12 @@ void FlvPlayer::onRecvRtmpPacket(RtmpPacket::Ptr packet) { onRtmpPacket(std::move(packet)); } +size_t FlvPlayer::getRecvSpeed() { + return TcpClient::getRecvSpeed(); +} + +size_t FlvPlayer::getRecvTotalBytes() { + return TcpClient::getRecvTotalBytes(); +} + }//mediakit \ No newline at end of file diff --git a/src/Rtmp/FlvPlayer.h b/src/Rtmp/FlvPlayer.h index 12ba54e3..9328acd9 100644 --- a/src/Rtmp/FlvPlayer.h +++ b/src/Rtmp/FlvPlayer.h @@ -23,6 +23,8 @@ public: void play(const std::string &url) override; void teardown() override; + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; protected: void onResponseHeader(const std::string &status, const HttpHeader &header) override; diff --git a/src/Rtmp/Rtmp.cpp b/src/Rtmp/Rtmp.cpp index 76a9da99..fe1ff1c8 100644 --- a/src/Rtmp/Rtmp.cpp +++ b/src/Rtmp/Rtmp.cpp @@ -55,6 +55,36 @@ AudioMeta::AudioMeta(const AudioTrack::Ptr &audio) { _metadata.set("audiocodecid", Factory::getAmfByCodecId(audio->getCodecId())); } +uint8_t getCodecFlags(CodecId cid) { + switch (cid) { +#define XX(a, b, c) \ + case a: return static_cast(b); + RTMP_CODEC_MAP(XX) +#undef XX + default: return 0; + } +} + +uint32_t getCodecFourCC(CodecId cid) { + switch (cid) { +#define XX(a, b, c) \ + case a: return static_cast(c); + RTMP_CODEC_MAP(XX) +#undef XX + default: return 0; + } +} + +CodecId getFourccCodec(uint32_t id) { + switch (id) { +#define XX(a, b, c) \ + case (uint32_t)c: return a; + RTMP_CODEC_MAP(XX) +#undef XX + default: return CodecInvalid; + } +} + uint8_t getAudioRtmpFlags(const Track::Ptr &track) { track->update(); switch (track->getTrackType()) { @@ -167,7 +197,11 @@ bool RtmpPacket::isVideoKeyFrame() const { bool RtmpPacket::isConfigFrame() const { switch (type_id) { case MSG_AUDIO: { - return (RtmpAudioCodec)getRtmpCodecId() == RtmpAudioCodec::aac && (RtmpAACPacketType)buffer[1] == RtmpAACPacketType::aac_config_header; + switch ((RtmpAudioCodec)getRtmpCodecId()) { + case RtmpAudioCodec::aac: return (RtmpAACPacketType)buffer[1] == RtmpAACPacketType::aac_config_header; + case RtmpAudioCodec::ex_header: return (RtmpPacketType)(buffer[0] & 0x0f) == RtmpPacketType::PacketTypeSequenceStart; + default: return false; + } } case MSG_VIDEO: { if (!isVideoKeyFrame()) { @@ -271,6 +305,8 @@ CodecId parseVideoRtmpPacket(const uint8_t *data, size_t size, RtmpPacketInfo *i switch ((RtmpVideoCodec)ntohl(enhanced_header->fourcc)) { case RtmpVideoCodec::fourcc_av1: info->codec = CodecAV1; break; case RtmpVideoCodec::fourcc_vp9: info->codec = CodecVP9; break; + case RtmpVideoCodec::fourcc_vp8: info->codec = CodecVP8; break; + case RtmpVideoCodec::fourcc_avc1: info->codec = CodecH264; break; case RtmpVideoCodec::fourcc_hevc: info->codec = CodecH265; break; default: WarnL << "Rtmp video codec not supported: " << std::string((char *)data + 1, 4); } @@ -292,6 +328,21 @@ CodecId parseVideoRtmpPacket(const uint8_t *data, size_t size, RtmpPacketInfo *i info->video.h264_pkt_type = (RtmpH264PacketType)classic_header->h264_pkt_type; break; } + case RtmpVideoCodec::vp8: { + CHECK(size >= 0, "Invalid rtmp buffer size: ", size); + info->codec = CodecVP8; + break; + } + case RtmpVideoCodec::vp9: { + CHECK(size >= 0, "Invalid rtmp buffer size: ", size); + info->codec = CodecVP9; + break; + } + case RtmpVideoCodec::av1: { + CHECK(size >= 0, "Invalid rtmp buffer size: ", size); + info->codec = CodecAV1; + break; + } default: WarnL << "Rtmp video codec not supported: " << (int)classic_header->codec_id; break; } } diff --git a/src/Rtmp/Rtmp.h b/src/Rtmp/Rtmp.h index 9b1c135a..aa301bb4 100644 --- a/src/Rtmp/Rtmp.h +++ b/src/Rtmp/Rtmp.h @@ -306,11 +306,15 @@ enum class RtmpVideoCodec : uint32_t { screen_video2 = 6, // Screen video version 2 h264 = 7, // avc h265 = 12, // 国内扩展 - + av1 = 13, // 国内扩展 + vp8 = 14, // 国内扩展 + vp9 = 15, // 国内扩展 // 增强型rtmp FourCC [AUTO-TRANSLATED:442b77fb] // Enhanced rtmp FourCC + fourcc_vp8 = MKBETAG('v', 'p', '0', '8'), fourcc_vp9 = MKBETAG('v', 'p', '0', '9'), fourcc_av1 = MKBETAG('a', 'v', '0', '1'), + fourcc_avc1 = MKBETAG('a', 'v', 'c', '1'), fourcc_hevc = MKBETAG('h', 'v', 'c', '1') }; @@ -354,7 +358,7 @@ enum class RtmpPacketType : uint8_t { //https://rtmp.veriskope.com/pdf/video_file_format_spec_v10_1.pdf // UB [4]; Format of SoundData -enum class RtmpAudioCodec : uint8_t { +enum class RtmpAudioCodec : uint32_t { /** 0 = Linear PCM, platform endian 1 = ADPCM @@ -375,10 +379,30 @@ enum class RtmpAudioCodec : uint8_t { mp3 = 2, g711a = 7, g711u = 8, + ex_header = 9, // Enhanced audio; new, used to signal FOURCC mode aac = 10, - opus = 13 // 国内扩展 + opus = 13, // 国内扩展 + fourcc_opus = MKBETAG('O', 'p', 'u', 's'), + fourcc_mp3 = MKBETAG('.', 'm', 'p', '3'), + fourcc_aac = MKBETAG('m', 'p', '4', 'a'), + fourcc_ac3 = MKBETAG('a', 'c', '-', '3'), + fourcc_flac = MKBETAG('f', 'L', 'a', 'C'), + }; +#define RTMP_CODEC_MAP(XX) \ + XX(CodecH264, RtmpVideoCodec::h264, RtmpVideoCodec::fourcc_avc1) \ + XX(CodecH265, RtmpVideoCodec::h265, RtmpVideoCodec::fourcc_hevc) \ + XX(CodecVP8, RtmpVideoCodec::vp8, RtmpVideoCodec::fourcc_vp8) \ + XX(CodecVP9, RtmpVideoCodec::vp9, RtmpVideoCodec::fourcc_vp9) \ + XX(CodecAV1, RtmpVideoCodec::av1, RtmpVideoCodec::fourcc_av1) \ + XX(CodecAAC, RtmpAudioCodec::aac, RtmpAudioCodec::fourcc_aac) \ + XX(CodecMP3, RtmpAudioCodec::mp3, RtmpAudioCodec::fourcc_mp3) \ + XX(CodecOpus, RtmpAudioCodec::opus, RtmpAudioCodec::fourcc_opus) +uint32_t getCodecFourCC(CodecId cid); +CodecId getFourccCodec(uint32_t id); +uint8_t getCodecFlags(CodecId cid); + // UI8; enum class RtmpAACPacketType : uint8_t { aac_config_header = 0, // AAC sequence header diff --git a/src/Rtmp/RtmpPlayer.cpp b/src/Rtmp/RtmpPlayer.cpp index c758d98d..359030da 100644 --- a/src/Rtmp/RtmpPlayer.cpp +++ b/src/Rtmp/RtmpPlayer.cpp @@ -452,4 +452,12 @@ void RtmpPlayer::seekToMilliSecond(uint32_t seekMS){ }); } +size_t RtmpPlayer::getRecvSpeed() { + return TcpClient::getRecvSpeed(); +} + +size_t RtmpPlayer::getRecvTotalBytes() { + return TcpClient::getRecvTotalBytes(); +} + } /* namespace mediakit */ diff --git a/src/Rtmp/RtmpPlayer.h b/src/Rtmp/RtmpPlayer.h index d11cda2a..423b637b 100644 --- a/src/Rtmp/RtmpPlayer.h +++ b/src/Rtmp/RtmpPlayer.h @@ -37,6 +37,9 @@ public: void speed(float speed) override; void teardown() override; + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; + protected: virtual bool onMetadata(const AMFValue &val) = 0; virtual void onRtmpPacket(RtmpPacket::Ptr chunk_data) = 0; diff --git a/src/Rtmp/RtmpProtocol.cpp b/src/Rtmp/RtmpProtocol.cpp index 191fcbed..daf43564 100644 --- a/src/Rtmp/RtmpProtocol.cpp +++ b/src/Rtmp/RtmpProtocol.cpp @@ -26,11 +26,35 @@ using namespace toolkit; #define S2_FMS_KEY_SIZE 68 #define C1_OFFSET_SIZE 4 + #ifdef ENABLE_OPENSSL #include "Util/SSLBox.h" #include #include +static uint8_t FMSKey[] = { + 0x47, 0x65, 0x6e, 0x75, 0x69, 0x6e, 0x65, 0x20, + 0x41, 0x64, 0x6f, 0x62, 0x65, 0x20, 0x46, 0x6c, + 0x61, 0x73, 0x68, 0x20, 0x4d, 0x65, 0x64, 0x69, + 0x61, 0x20, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x20, 0x30, 0x30, 0x31, // Genuine Adobe Flash Media Server 001 + 0xf0, 0xee, 0xc2, 0x4a, 0x80, 0x68, 0xbe, 0xe8, + 0x2e, 0x00, 0xd0, 0xd1, 0x02, 0x9e, 0x7e, 0x57, + 0x6e, 0xec, 0x5d, 0x2d, 0x29, 0x80, 0x6f, 0xab, + 0x93, 0xb8, 0xe6, 0x36, 0xcf, 0xeb, 0x31, 0xae +}; // 68 + +static uint8_t FPKey[] = { + 0x47, 0x65, 0x6E, 0x75, 0x69, 0x6E, 0x65, 0x20, + 0x41, 0x64, 0x6F, 0x62, 0x65, 0x20, 0x46, 0x6C, + 0x61, 0x73, 0x68, 0x20, 0x50, 0x6C, 0x61, 0x79, + 0x65, 0x72, 0x20, 0x30, 0x30, 0x31, // Genuine Adobe Flash Player 001 + 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, + 0x2E, 0x00, 0xD0, 0xD1, 0x02, 0x9E, 0x7E, 0x57, + 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, + 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE +}; // 62 + static string openssl_HMACsha256(const void *key, size_t key_len, const void *data, size_t data_len){ std::shared_ptr out(new char[32], [](char *ptr) { delete[] ptr; }); unsigned int out_len; @@ -329,8 +353,16 @@ const char* RtmpProtocol::handle_S0S1S2(const char *data, size_t len, const func } // 发送 C2 [AUTO-TRANSLATED:e51c339e] // Send C2 - const char *pcC2 = data + 1; - onSendRawData(obtainBuffer(pcC2, C1_HANDSHARK_SIZE)); + uint8_t *pS1 = (uint8_t*)data + 1; + RtmpHandshake c2(0); + memcpy(&c2, pS1, sizeof(c2)); +#ifdef ENABLE_OPENSSL + if(pS1[4] >=3){ // 复杂握手计算c2 + handle_S1_complex((char*)pS1, c2); + } +#endif + + onSendRawData(obtainBuffer(&c2, C1_HANDSHARK_SIZE)); // 握手结束 [AUTO-TRANSLATED:9df763ff] // Handshake finished _next_step_func = [this](const char *data, size_t len) { @@ -408,7 +440,7 @@ void RtmpProtocol::handle_C1_complex(const char *data){ check_C1_Digest(digest, c1_joined); send_complex_S0S1S2(0, digest); -// InfoL << "schema0"; +// InfoL << "schema0"; } catch (std::exception &) { // 貌似flash从来都不用schema1 [AUTO-TRANSLATED:2c6d140f] // It seems that flash never uses schema1 @@ -426,40 +458,70 @@ void RtmpProtocol::handle_C1_complex(const char *data){ check_C1_Digest(digest, c1_joined); send_complex_S0S1S2(1, digest); -// InfoL << "schema1"; +// InfoL << "schema1"; } catch (std::exception &) { -// WarnL << "try rtmp complex schema1 failed:" << ex.what(); + //WarnL << "try rtmp complex schema1 failed:" << ex.what(); handle_C1_simple(data); } } } -#if !defined(u_int8_t) -#define u_int8_t unsigned char -#endif // !defined(u_int8_t) +void RtmpProtocol::check_S1_Digest(const std::string &digest,const std::string &data){ + auto sha256 = openssl_HMACsha256(FMSKey, S1_FMS_KEY_SIZE, data.data(), data.size()); + if (sha256 != digest) { + throw std::runtime_error("digest mismatched"); + } else { + InfoL << "check rtmp complex handshark success!"; + } +} -static u_int8_t FMSKey[] = { - 0x47, 0x65, 0x6e, 0x75, 0x69, 0x6e, 0x65, 0x20, - 0x41, 0x64, 0x6f, 0x62, 0x65, 0x20, 0x46, 0x6c, - 0x61, 0x73, 0x68, 0x20, 0x4d, 0x65, 0x64, 0x69, - 0x61, 0x20, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x20, 0x30, 0x30, 0x31, // Genuine Adobe Flash Media Server 001 - 0xf0, 0xee, 0xc2, 0x4a, 0x80, 0x68, 0xbe, 0xe8, - 0x2e, 0x00, 0xd0, 0xd1, 0x02, 0x9e, 0x7e, 0x57, - 0x6e, 0xec, 0x5d, 0x2d, 0x29, 0x80, 0x6f, 0xab, - 0x93, 0xb8, 0xe6, 0x36, 0xcf, 0xeb, 0x31, 0xae -}; // 68 +void RtmpProtocol::handle_S1_complex(const char *data,RtmpHandshake &c2){ -static u_int8_t FPKey[] = { - 0x47, 0x65, 0x6E, 0x75, 0x69, 0x6E, 0x65, 0x20, - 0x41, 0x64, 0x6F, 0x62, 0x65, 0x20, 0x46, 0x6C, - 0x61, 0x73, 0x68, 0x20, 0x50, 0x6C, 0x61, 0x79, - 0x65, 0x72, 0x20, 0x30, 0x30, 0x31, // Genuine Adobe Flash Player 001 - 0xF0, 0xEE, 0xC2, 0x4A, 0x80, 0x68, 0xBE, 0xE8, - 0x2E, 0x00, 0xD0, 0xD1, 0x02, 0x9E, 0x7E, 0x57, - 0x6E, 0xEC, 0x5D, 0x2D, 0x29, 0x80, 0x6F, 0xAB, - 0x93, 0xB8, 0xE6, 0x36, 0xCF, 0xEB, 0x31, 0xAE -}; // 62 + const char *s1_start = data; + const char *schema_start = s1_start + 8; + char *digest_start; + std::string digest; + try { + /* c1s1 schema0 + time: 4bytes + version: 4bytes + key: 764bytes + digest: 764bytes + */ + digest = get_C1_digest((uint8_t *) schema_start + C1_SCHEMA_SIZE, &digest_start); + string s1_joined(s1_start, C1_HANDSHARK_SIZE); + s1_joined.erase(digest_start - s1_start, C1_DIGEST_SIZE); + check_S1_Digest(digest, s1_joined); + //InfoL << "schema0"; + } catch (std::exception &ex) { + // 貌似flash从来都不用schema1 [AUTO-TRANSLATED:2c6d140f] + // It seems that flash never uses schema1 + //WarnL << "try rtmp complex schema0 failed:" << ex.what(); + try { + /* c1s1 schema1 + time: 4bytes + version: 4bytes + digest: 764bytes + key: 764bytes + */ + digest = get_C1_digest((uint8_t *) schema_start, &digest_start); + string s1_joined(s1_start, C1_HANDSHARK_SIZE); + s1_joined.erase(digest_start - s1_start, C1_DIGEST_SIZE); + check_S1_Digest(digest, s1_joined); + //send_complex_S0S1S2(1, digest); + //InfoL << "schema1"; + } catch (std::exception &ex) { + WarnL << "try rtmp complex schema1 failed:" << ex.what(); + return; + } + } + + //InfoL << "send complex C2"; + auto c2_key = openssl_HMACsha256(FPKey, sizeof(FPKey), digest.data(), digest.size()); + std::string c2_str((char*)(&c2), sizeof(c2)- C1_DIGEST_SIZE); + auto c2_digest = openssl_HMACsha256(c2_key.data(), c2_key.size(), c2_str.data(), c2_str.size()); + memcpy(c2.random + RANDOM_LEN - C1_DIGEST_SIZE, c2_digest.data(), C1_DIGEST_SIZE); +} void RtmpProtocol::check_C1_Digest(const string &digest,const string &data){ auto sha256 = openssl_HMACsha256(FPKey, C1_FPKEY_SIZE, data.data(), data.size()); diff --git a/src/Rtmp/RtmpProtocol.h b/src/Rtmp/RtmpProtocol.h index 7a12495c..be51bbb1 100644 --- a/src/Rtmp/RtmpProtocol.h +++ b/src/Rtmp/RtmpProtocol.h @@ -63,14 +63,16 @@ protected: void sendRtmp(uint8_t type, uint32_t stream_index, const std::string &buffer, uint32_t stamp, int chunk_id); void sendRtmp(uint8_t type, uint32_t stream_index, const toolkit::Buffer::Ptr &buffer, uint32_t stamp, int chunk_id); toolkit::BufferRaw::Ptr obtainBuffer(const void *data = nullptr, size_t len = 0); - + private: void handle_C1_simple(const char *data); #ifdef ENABLE_OPENSSL + void handle_S1_complex(const char *data, RtmpHandshake &c2); void handle_C1_complex(const char *data); std::string get_C1_digest(const uint8_t *ptr,char **digestPos); std::string get_C1_key(const uint8_t *ptr); void check_C1_Digest(const std::string &digest,const std::string &data); + void check_S1_Digest(const std::string &digest,const std::string &data); void send_complex_S0S1S2(int schemeType,const std::string &digest); #endif //ENABLE_OPENSSL diff --git a/src/Rtmp/RtmpPusher.cpp b/src/Rtmp/RtmpPusher.cpp index badcd01c..2032b3ec 100644 --- a/src/Rtmp/RtmpPusher.cpp +++ b/src/Rtmp/RtmpPusher.cpp @@ -332,6 +332,12 @@ void RtmpPusher::onRtmpChunk(RtmpPacket::Ptr packet) { } } +size_t RtmpPusher::getSendSpeed() { + return TcpClient::getSendSpeed(); +} +size_t RtmpPusher::getSendTotalBytes() { + return TcpClient::getSendTotalBytes(); +} } /* namespace mediakit */ diff --git a/src/Rtmp/RtmpPusher.h b/src/Rtmp/RtmpPusher.h index 41caac49..878fa8f4 100644 --- a/src/Rtmp/RtmpPusher.h +++ b/src/Rtmp/RtmpPusher.h @@ -27,6 +27,9 @@ public: void publish(const std::string &url) override ; void teardown() override; + size_t getSendSpeed() override; + size_t getSendTotalBytes() override; + protected: //for Tcpclient override void onRecv(const toolkit::Buffer::Ptr &buf) override; diff --git a/src/Rtmp/RtmpSession.cpp b/src/Rtmp/RtmpSession.cpp index 6fc3c216..89e36f87 100644 --- a/src/Rtmp/RtmpSession.cpp +++ b/src/Rtmp/RtmpSession.cpp @@ -96,6 +96,8 @@ void RtmpSession::onCmd_connect(AMFDecoder &dec) { // 赋值rtmp app _media_info.app = params["app"].as_string(); + _media_info.protocol = overSsl() ? "rtmps" : "rtmp"; + bool ok = true; //(app == APP_NAME); AMFValue version(AMF_OBJECT); version.set("fmsVer", "FMS/3,0,1,123"); @@ -306,7 +308,7 @@ void RtmpSession::sendPlayResponse(const string &err, const RtmpMediaSource::Ptr weak_ptr weak_self = static_pointer_cast(shared_from_this()); _ring_reader->setGetInfoCB([weak_self]() { Any ret; - ret.set(static_pointer_cast(weak_self.lock())); + ret.set(static_pointer_cast(weak_self.lock())); return ret; }); _ring_reader->setReadCB([weak_self](const RtmpMediaSource::RingDataType &pkt) { @@ -589,9 +591,7 @@ void RtmpSession::onSendMedia(const RtmpPacket::Ptr &pkt) { } bool RtmpSession::close(MediaSource &sender) { - //此回调在其他线程触发 - string err = StrPrinter << "close media: " << sender.getUrl(); - safeShutdown(SockException(Err_shutdown, err)); + shutdown(SockException(Err_shutdown, "close media: " + sender.getUrl())); return true; } diff --git a/src/Rtp/Decoder.cpp b/src/Rtp/Decoder.cpp index 61cbdd7c..42abe783 100644 --- a/src/Rtp/Decoder.cpp +++ b/src/Rtp/Decoder.cpp @@ -11,6 +11,7 @@ #include "Decoder.h" #include "PSDecoder.h" #include "TSDecoder.h" +#include "Common/config.h" #include "Extension/Factory.h" #if defined(ENABLE_RTPPROXY) || defined(ENABLE_HLS) @@ -122,11 +123,18 @@ void DecoderImp::onDecode(int stream, int codecid, int flags, int64_t pts, int64 WarnL << "Unsupported codec :" << getCodecName(codec); return; } + GET_CONFIG(bool, merge_frame, RtpProxy::kMergeFrame) auto frame = Factory::getFrameFromPtr(codec, (char *)data, bytes, dts, pts); - if (getTrackType(codec) != TrackVideo) { + if (getTrackType(codec) != TrackVideo || !merge_frame) { onFrame(stream, frame); + if (_last_is_keyframe && _video_merge) { + // 上次是关键帧,收到音频后,说明帧收齐了 + _video_merge->flush(); + } return; } + _last_is_keyframe = frame->keyFrame() || frame->configFrame(); + _video_merge = &ref.second; ref.second.inputFrame(frame, [this, stream, codec](uint64_t dts, uint64_t pts, const Buffer::Ptr &buffer, bool) { onFrame(stream, Factory::getFrameFromBuffer(codec, buffer, dts, pts)); }); @@ -143,7 +151,7 @@ void DecoderImp::onTrack(int index, const Track::Ptr &track) { track->setIndex(index); auto &ref = _tracks[index]; if (ref.first) { - WarnL << "Already existed a same track: " << index << ", codec: " << track->getCodecName(); + // WarnL << "Already existed a same track: " << index << ", codec: " << track->getCodecName(); return; } ref.first = track; diff --git a/src/Rtp/Decoder.h b/src/Rtp/Decoder.h index 4902157c..8a8b90a1 100644 --- a/src/Rtp/Decoder.h +++ b/src/Rtp/Decoder.h @@ -59,13 +59,14 @@ private: private: bool _finished = false; bool _have_video = false; + bool _last_is_keyframe = false; Decoder::Ptr _decoder; MediaSinkInterface *_sink; - class FrameMergerImp : public FrameMerger { public: FrameMergerImp() : FrameMerger(FrameMerger::none) {} }; + FrameMergerImp *_video_merge = nullptr; std::unordered_map > _tracks; }; diff --git a/src/Rtp/PSDecoder.cpp b/src/Rtp/PSDecoder.cpp index 21014b4b..d7961869 100644 --- a/src/Rtp/PSDecoder.cpp +++ b/src/Rtp/PSDecoder.cpp @@ -56,7 +56,7 @@ ssize_t PSDecoder::input(const uint8_t *data, size_t bytes) { const char *PSDecoder::onSearchPacketTail(const char *data, size_t len) { try { auto ret = ps_demuxer_input(static_cast(_ps_demuxer), reinterpret_cast(data), len); - if (ret >= 0) { + if (ret >= 0 && ret <= (ssize_t)len) { // 解析成功全部或部分 [AUTO-TRANSLATED:a8085d34] // Parse successful, all or part return data + ret; diff --git a/src/Rtp/RawEncoder.cpp b/src/Rtp/RawEncoder.cpp index c14254bd..8c58935d 100644 --- a/src/Rtp/RawEncoder.cpp +++ b/src/Rtp/RawEncoder.cpp @@ -13,7 +13,7 @@ #include "RawEncoder.h" #include "Extension/Factory.h" #include "Rtsp/RtspMuxer.h" -#include "Common//config.h" +#include "Common/config.h" using namespace toolkit; diff --git a/src/Rtp/RtpProcess.cpp b/src/Rtp/RtpProcess.cpp index 61079161..a5312d73 100644 --- a/src/Rtp/RtpProcess.cpp +++ b/src/Rtp/RtpProcess.cpp @@ -107,6 +107,7 @@ bool RtpProcess::inputRtp(bool is_udp, const Socket::Ptr &sock, const char *data if (!_auth_err.empty()) { throw toolkit::SockException(toolkit::Err_other, _auth_err); } + auto header = (RtpHeader *) data; if (_sock != sock) { // 第一次运行本函数 [AUTO-TRANSLATED:a1d7ac17] // First time running this function @@ -114,7 +115,7 @@ bool RtpProcess::inputRtp(bool is_udp, const Socket::Ptr &sock, const char *data _sock = sock; _addr.reset(new sockaddr_storage(*((sockaddr_storage *)addr))); if (first) { - emitOnPublish(); + emitOnPublish(ntohl(header->ssrc)); _cache_ticker.resetTime(); } } @@ -127,10 +128,10 @@ bool RtpProcess::inputRtp(bool is_udp, const Socket::Ptr &sock, const char *data fwrite((uint8_t *) data, len, 1, _save_file_rtp.get()); } if (!_process) { + _media_info.protocol = is_udp ? "udp" : "tcp"; _process = std::make_shared(_media_info, this); } - auto header = (RtpHeader *) data; onRtp(ntohs(header->seq), ntohl(header->stamp), 0/*不发送sr,所以可以设置为0*/ , 90000/*ps/ts流时间戳按照90K采样率*/, len); GET_CONFIG(string, dump_dir, RtpProxy::kDumpDir); @@ -202,27 +203,24 @@ void RtpProcess::doCachedFunc() { } bool RtpProcess::alive() { - if (_stop_rtp_check.load()) { - if(_last_check_alive.elapsedTime() > 5 * 60 * 1000){ - // 最多暂停5分钟的rtp超时检测,因为NAT映射有效期一般不会太长 [AUTO-TRANSLATED:2df59aad] - // Pause the RTP timeout detection for a maximum of 5 minutes, because the NAT mapping validity period is generally not very long. - _stop_rtp_check = false; - } else { + if (_pause_timeout) { + if (_last_check_alive.elapsedTime() < _pause_seconds * 1000) { return true; } + // 最多暂停_pause_seconds秒的rtp超时检测,因为NAT映射有效期一般不会太长 + _pause_timeout = false; } _last_check_alive.resetTime(); GET_CONFIG(uint64_t, timeoutSec, RtpProxy::kTimeoutSec) - if (_last_frame_time.elapsedTime() / 1000 < timeoutSec) { - return true; - } - return false; + return _last_frame_time.elapsedTime() < timeoutSec * 1000; } -void RtpProcess::setStopCheckRtp(bool is_check){ - _stop_rtp_check = is_check; - if (!is_check) { +void RtpProcess::pauseRtpTimeout(bool pause, uint32_t pause_seconds) { + _pause_timeout = pause; + // 默认5分钟恢复超时监测 + _pause_seconds = pause_seconds ? pause_seconds : 300; + if (!pause) { _last_frame_time.resetTime(); } } @@ -270,15 +268,15 @@ string RtpProcess::getIdentifier() const { return _media_info.stream; } -void RtpProcess::emitOnPublish() { +void RtpProcess::emitOnPublish(uint32_t ssrc) { weak_ptr weak_self = shared_from_this(); - Broadcast::PublishAuthInvoker invoker = [weak_self](const string &err, const ProtocolOption &option) { + Broadcast::PublishAuthInvoker invoker = [weak_self, ssrc](const string &err, const ProtocolOption &option) { auto strong_self = weak_self.lock(); if (!strong_self) { return; } auto poller = strong_self->getOwnerPoller(MediaSource::NullMediaSource()); - poller->async([weak_self, err, option]() { + poller->async([weak_self, err, option, ssrc]() { auto strong_self = weak_self.lock(); if (!strong_self) { return; @@ -292,7 +290,7 @@ void RtpProcess::emitOnPublish() { } strong_self->_muxer->setMediaListener(strong_self); strong_self->doCachedFunc(); - InfoP(strong_self) << "允许RTP推流"; + InfoP(strong_self) << "允许RTP推流,ssrc: " << printSSRC(ssrc); } else { strong_self->_auth_err = err; WarnP(strong_self) << "禁止RTP推流:" << err; diff --git a/src/Rtp/RtpProcess.h b/src/Rtp/RtpProcess.h index eecdf694..08b91e5e 100644 --- a/src/Rtp/RtpProcess.h +++ b/src/Rtp/RtpProcess.h @@ -69,12 +69,11 @@ public: void setOnDetach(onDetachCB cb); /** - * 设置onDetach事件回调,false检查RTP超时,true停止 - * Set onDetach event callback, false checks RTP timeout, true stops - - * [AUTO-TRANSLATED:2780397f] + * 暂停或恢复rtp超时监测 + * @param pause 是否暂停超时检测 + * @param pause_seconds 暂停超时检测最大时间(单位秒),超过这个时间后将恢复超时检测; 设置为0时默认为300 */ - void setStopCheckRtp(bool is_check=false); + void pauseRtpTimeout(bool pause, uint32_t pause_seconds = 0); /** * 设置为单track,单音频/单视频时可以加快媒体注册速度 @@ -122,17 +121,19 @@ protected: private: RtpProcess(const MediaTuple &tuple); - void emitOnPublish(); + void emitOnPublish(uint32_t ssrc); void doCachedFunc(); bool alive(); void onManager(); void createTimer(); private: - OnlyTrack _only_track = kAll; - std::string _auth_err; + bool _pause_timeout = false; + uint32_t _pause_seconds = 5 * 60; uint64_t _dts = 0; uint64_t _total_bytes = 0; + OnlyTrack _only_track = kAll; + std::string _auth_err; std::unique_ptr _addr; toolkit::Socket::Ptr _sock; MediaInfo _media_info; @@ -142,7 +143,6 @@ private: std::shared_ptr _save_file_video; ProcessInterface::Ptr _process; MultiMediaSourceMuxer::Ptr _muxer; - std::atomic_bool _stop_rtp_check{false}; toolkit::Timer::Ptr _timer; toolkit::Ticker _last_check_alive; std::recursive_mutex _func_mtx; diff --git a/src/Rtp/RtpSender.cpp b/src/Rtp/RtpSender.cpp index 88f06961..3ae929e7 100644 --- a/src/Rtp/RtpSender.cpp +++ b/src/Rtp/RtpSender.cpp @@ -35,16 +35,25 @@ RtpSender::~RtpSender() { } } -void RtpSender::startSend(const MediaSourceEvent::SendRtpArgs &args, const function &cb){ +void RtpSender::startSend(const MediaSourceEvent &sender, const MediaSourceEvent::SendRtpArgs &args, const function &cb){ + auto origin_socket = sender.getOriginSock(MediaSource::NullMediaSource()); + _origin_socket = dynamic_pointer_cast(origin_socket); + if (!_origin_socket) { + auto process = dynamic_pointer_cast(origin_socket); + if (process) { + _origin_socket = process->getSock(); + } + } + _args = args; if (!_interface) { // 重连时不重新创建对象 [AUTO-TRANSLATED:b788cd5d] // Do not recreate the object when reconnecting auto lam = [this](std::shared_ptr> list) { onFlushRtpList(std::move(list)); }; switch (args.data_type) { - case MediaSourceEvent::SendRtpArgs::kRtpPS: _interface = std::make_shared(lam, atoi(args.ssrc.data()), args.pt, true); break; - case MediaSourceEvent::SendRtpArgs::kRtpTS: _interface = std::make_shared(lam, atoi(args.ssrc.data()), args.pt, false); break; - case MediaSourceEvent::SendRtpArgs::kRtpES: _interface = std::make_shared(lam, atoi(args.ssrc.data()), args.pt, args.only_audio); break; + case MediaSourceEvent::SendRtpArgs::kRtpPS: _interface = std::make_shared(lam, stoll(args.ssrc), args.pt, true); break; + case MediaSourceEvent::SendRtpArgs::kRtpTS: _interface = std::make_shared(lam, stoll(args.ssrc), args.pt, false); break; + case MediaSourceEvent::SendRtpArgs::kRtpES: _interface = std::make_shared(lam, stoll(args.ssrc), args.pt, args.only_audio); break; default: CHECK(0, "invalid rtp type: " + to_string(args.data_type)); break; } } @@ -313,6 +322,15 @@ void RtpSender::onConnect() { } }); } + + if (_socket_rtp->sockType() == toolkit::SockNum::Sock_TCP && _origin_socket) { + // rtp 端口是TCP端口,转发速度应当控制收流速度 + auto origin_socket = _origin_socket; + _socket_rtp->setOnFlush([origin_socket]() { + origin_socket->enableRecv(true); + return true; + }); + } InfoL << "startSend rtp success: " << _socket_rtp->get_peer_ip() << ":" << _socket_rtp->get_peer_port() << ", data_type: " << _args.data_type << ", con_type: " << _args.con_type; } @@ -368,7 +386,7 @@ void RtpSender::onSendRtpUdp(const toolkit::Buffer::Ptr &buf, bool check) { _rtcp_send_ticker.resetTime(); // rtcp ssrc为rtp ssrc + 1 [AUTO-TRANSLATED:318fada3] // rtcp ssrc is rtp ssrc + 1 - auto sr = _rtcp_context->createRtcpSR(atoi(_args.ssrc.data()) + 1); + auto sr = _rtcp_context->createRtcpSR(stoll(_args.ssrc) + 1); // send sender report rtcp _socket_rtcp->send(sr); } @@ -433,6 +451,9 @@ void RtpSender::onFlushRtpList(shared_ptr> rtp_list) { } default: CHECK(0); } + if (_args.enable_origin_recv_limit && _socket_rtp->sockType() == toolkit::SockNum::Sock_TCP && _socket_rtp->isSocketBusy() && _origin_socket) { + _origin_socket->enableRecv(false); + } }); }; if (_args.con_type != MediaSourceEvent::SendRtpArgs::kVoiceTalk) { @@ -457,5 +478,49 @@ void RtpSender::setOnClose(std::function _on_close = std::move(on_close); } +size_t RtpSender::getSendSpeed() const { + size_t ret = 0; + if (_socket_rtp) { + ret += _socket_rtp->getSendSpeed(); + } + if (_socket_rtcp) { + ret += _socket_rtcp->getSendSpeed(); + } + return ret; +} + +size_t RtpSender::getRecvSpeed() const { + size_t ret = 0; + if (_socket_rtp) { + ret += _socket_rtp->getRecvSpeed(); + } + if (_socket_rtcp) { + ret += _socket_rtcp->getRecvSpeed(); + } + return ret; +} + +size_t RtpSender::getRecvTotalBytes() const { + size_t ret = 0; + if (_socket_rtp) { + ret += _socket_rtp->getRecvTotalBytes(); + } + if (_socket_rtcp) { + ret += _socket_rtcp->getRecvTotalBytes(); + } + return ret; +} + +size_t RtpSender::getSendTotalBytes() const { + size_t ret = 0; + if (_socket_rtp) { + ret += _socket_rtp->getSendTotalBytes(); + } + if (_socket_rtcp) { + ret += _socket_rtcp->getSendTotalBytes(); + } + return ret; +} + } // namespace mediakit #endif // defined(ENABLE_RTPPROXY) diff --git a/src/Rtp/RtpSender.h b/src/Rtp/RtpSender.h index dbdf4a1c..8ed3eb80 100644 --- a/src/Rtp/RtpSender.h +++ b/src/Rtp/RtpSender.h @@ -40,7 +40,7 @@ public: * [AUTO-TRANSLATED:c31bd9b3] */ - void startSend(const MediaSourceEvent::SendRtpArgs &args, const std::function &cb); + void startSend(const MediaSourceEvent &sender, const MediaSourceEvent::SendRtpArgs &args, const std::function &cb); /** * 输入帧数据 @@ -94,6 +94,11 @@ public: */ void setOnClose(std::function on_close); + size_t getSendSpeed() const; + size_t getRecvSpeed() const; + size_t getRecvTotalBytes() const; + size_t getSendTotalBytes() const; + private: // 合并写输出 [AUTO-TRANSLATED:23544836] // Merge write output @@ -111,6 +116,7 @@ private: private: bool _is_connect = false; + toolkit::Socket::Ptr _origin_socket; MediaSourceEvent::SendRtpArgs _args; toolkit::Socket::Ptr _socket_rtp; toolkit::Socket::Ptr _socket_rtcp; diff --git a/src/Rtp/RtpServer.cpp b/src/Rtp/RtpServer.cpp index 580f8159..8aaaf494 100644 --- a/src/Rtp/RtpServer.cpp +++ b/src/Rtp/RtpServer.cpp @@ -62,7 +62,12 @@ public: RtpProcess::Ptr getProcess() const { return _process; } void onRecvRtp(const Socket::Ptr &sock, const Buffer::Ptr &buf, struct sockaddr *addr) { - _process->inputRtp(true, sock, buf->data(), buf->size(), addr); + try { + _process->inputRtp(true, sock, buf->data(), buf->size(), addr); + } catch (std::exception &ex) { + _process->onDetach(SockException(Err_shutdown, ex.what())); + return; + } // 统计rtp接受情况,用于发送rr包 [AUTO-TRANSLATED:bd2fbe7e] // Count RTP reception status, used to send RR packets auto header = (RtpHeader *)buf->data(); @@ -202,8 +207,8 @@ void RtpServer::start(uint16_t local_port, const char *local_ip, const MediaTupl TcpServer::Ptr tcp_server; if (tcp_mode == PASSIVE || tcp_mode == ACTIVE) { auto processor = helper ? helper->getProcess() : nullptr; - // 如果共享同一个processor对象,那么tcp server深圳为单线程模式确保线程安全 [AUTO-TRANSLATED:68bdd877] - // If the same processor object is shared, then the TCP server Shenzhen is in single-threaded mode to ensure thread safety + // 如果共享同一个processor对象,那么tcp server声明为单线程模式确保线程安全 [AUTO-TRANSLATED:68bdd877] + // If the same processor object is shared, declare the TCP server in single-threaded mode to ensure thread safety. tcp_server = std::make_shared(processor ? poller : nullptr); (*tcp_server)[RtpSession::kVhost] = tuple.vhost; (*tcp_server)[RtpSession::kApp] = tuple.app; diff --git a/src/Rtsp/RtpReceiver.h b/src/Rtsp/RtpReceiver.h index 4c9d10c8..8b9e4ba0 100644 --- a/src/Rtsp/RtpReceiver.h +++ b/src/Rtsp/RtpReceiver.h @@ -188,8 +188,14 @@ private: } iterator popIterator(iterator it) { - output(it->first, std::move(it->second)); - return _pkt_sort_cache_map.erase(it); + try { + output(it->first, std::move(it->second)); + return _pkt_sort_cache_map.erase(it); + } catch (...) { + // 防止抛异常未移除迭代器,导致rtp包为空 + _pkt_sort_cache_map.erase(it); + throw; + } } void output(SEQ seq, T packet) { diff --git a/src/Rtsp/Rtsp.cpp b/src/Rtsp/Rtsp.cpp index 27c7eb6f..7787d635 100644 --- a/src/Rtsp/Rtsp.cpp +++ b/src/Rtsp/Rtsp.cpp @@ -161,7 +161,7 @@ string SdpTrack::toString(uint16_t port) const { } default: break; } - return std::move(_printer); + return _printer; } static TrackType toTrackType(const string &str) { @@ -235,15 +235,25 @@ void SdpParser::load(const string &sdp) { auto &track = *track_ptr; auto it = track._attr.find("range"); if (it != track._attr.end()) { - char name[16] = { 0 }, start[16] = { 0 }, end[16] = { 0 }; - int ret = sscanf(it->second.data(), "%15[^=]=%15[^-]-%15s", name, start, end); + char name[16] = { 0 }, start[17] = { 0 }, end[17] = { 0 }; + int ret = sscanf(it->second.data(), "%15[^=]=%16[^-]-%16s", name, start, end); if (3 == ret || 2 == ret) { - if (strcmp(start, "now") == 0) { - strcpy(start, "0"); + // 保存 range 类型 + track._range_type = name; + if (strcmp(name, "clock") == 0) { + // clock 格式:clock=20251123T000000Z-20251124T000000Z + track._range_start_str = start; + track._range_end_str = end; + // 对于 clock 格式,不解析为数值 + } else { + // npt 格式或其他格式 + if (strcmp(start, "now") == 0) { + strcpy(start, "0"); + } + track._start = (float)atof(start); + track._end = (float)atof(end); + track._duration = track._end - track._start; } - track._start = (float)atof(start); - track._end = (float)atof(end); - track._duration = track._end - track._start; } } @@ -644,7 +654,7 @@ string RtpHeader::dumpString(size_t rtp_size) const { printer << "rtp size:" << rtp_size << "\r\n"; printer << "payload offset:" << getPayloadOffset() << "\r\n"; printer << "payload size:" << getPayloadSize(rtp_size) << "\r\n"; - return std::move(printer); + return printer; } /////////////////////////////////////////////////////////////////////// @@ -745,7 +755,7 @@ TitleSdp::TitleSdp(float dur_sec, const std::map &head } DefaultSdp::DefaultSdp(int payload_type, const Track &track) - : Sdp(track.getTrackType() == TrackVideo ? 9000 : static_cast(track).getAudioSampleRate(), payload_type) { + : Sdp(track.getTrackType() == TrackVideo ? 90000 : static_cast(track).getAudioSampleRate(), payload_type) { _printer << "m=" << track.getTrackTypeStr() << " 0 RTP/AVP " << payload_type << "\r\n"; auto bitrate = track.getBitRate() >> 10; if (bitrate) { diff --git a/src/Rtsp/Rtsp.h b/src/Rtsp/Rtsp.h index f94664c2..4f9cd137 100644 --- a/src/Rtsp/Rtsp.h +++ b/src/Rtsp/Rtsp.h @@ -53,7 +53,7 @@ typedef enum { XX(JPEG, TrackVideo, 26, 90000, 1, CodecJPEG) \ XX(nv, TrackVideo, 28, 90000, 1, CodecInvalid) \ XX(H261, TrackVideo, 31, 90000, 1, CodecInvalid) \ - XX(MPV, TrackVideo, 32, 90000, 1, CodecInvalid) \ + XX(MPV, TrackVideo, 32, 90000, 1, CodecMP2V) \ XX(MP2T, TrackVideo, 33, 90000, 1, CodecTS) \ XX(H263, TrackVideo, 34, 90000, 1, CodecInvalid) @@ -237,6 +237,9 @@ public: float _duration = 0; float _start = 0; float _end = 0; + std::string _range_type; // 新增:保存 range 类型,如 "npt" 或 "clock" + std::string _range_start_str; // 新增:保存原始 range start 字符串(用于 clock 格式) + std::string _range_end_str; // 新增:保存原始 range end 字符串(用于 clock 格式) std::map _other; std::multimap _attr; diff --git a/src/Rtsp/RtspMediaSource.h b/src/Rtsp/RtspMediaSource.h index ae4c6022..b0f617bd 100644 --- a/src/Rtsp/RtspMediaSource.h +++ b/src/Rtsp/RtspMediaSource.h @@ -132,7 +132,7 @@ public: * [AUTO-TRANSLATED:24b0ee74] */ - virtual uint16_t getSeqence(TrackType trackType) { + virtual uint16_t getSequence(TrackType trackType) { assert(trackType >= 0 && trackType < TrackMax); auto &track = _tracks[trackType]; if (!track) { diff --git a/src/Rtsp/RtspPlayer.cpp b/src/Rtsp/RtspPlayer.cpp index 6a2cea5c..dbf36af4 100644 --- a/src/Rtsp/RtspPlayer.cpp +++ b/src/Rtsp/RtspPlayer.cpp @@ -21,6 +21,12 @@ #include #include #include +#include +#include + +#if defined(_WIN32) +#include "Util/strptime_win.h" +#endif using namespace toolkit; using namespace std; @@ -91,22 +97,22 @@ void RtspPlayer::play(const string &strUrl) { _speed = (*this)[Client::kRtspSpeed].as(); DebugL << url._url << " " << (url._user.size() ? url._user : "null") << " " << (url._passwd.size() ? url._passwd : "null") << " " << _rtp_type; - weak_ptr weakSelf = static_pointer_cast(shared_from_this()); + weak_ptr weak_self = static_pointer_cast(shared_from_this()); float playTimeOutSec = (*this)[Client::kTimeoutMS].as() / 1000.0f; - _play_check_timer.reset(new Timer( - playTimeOutSec, - [weakSelf]() { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) { - return false; - } - strongSelf->onPlayResult_l(SockException(Err_timeout, "play rtsp timeout"), false); - return false; - }, - getPoller())); + _play_check_timer.reset(new Timer(playTimeOutSec,[weak_self]() { + if (auto strong_self = weak_self.lock()) { + strong_self->onPlayResult_l(SockException(Err_timeout, "play rtsp timeout"), false); + } + return false; + }, getPoller())); - if (!(*this)[Client::kNetAdapter].empty()) { - setNetAdapter((*this)[Client::kNetAdapter]); + auto &adapter = (*this)[Client::kNetAdapter]; + if (!adapter.empty()) { + setNetAdapter(std::move(adapter)); + } + auto &custom_header = (*this)[Client::kCustomHeader]; + if (!custom_header.empty()) { + _custom_header = mediakit::Parser::parseArgs(custom_header); } startConnect(url._host, url._port, playTimeOutSec); } @@ -174,12 +180,12 @@ bool RtspPlayer::handleAuthenticationFailure(const string ¶msStr) { return false; } -bool RtspPlayer::handleResponse(const string &cmd, const Parser &parser) { +bool RtspPlayer::handleResponse(const std::string &cmd, const Parser &parser, send_method_handler handler) { string authInfo = parser["WWW-Authenticate"]; // 发送DESCRIBE命令后的回复 [AUTO-TRANSLATED:39629cf0] // The response after sending the DESCRIBE command if ((parser.status() == "401") && handleAuthenticationFailure(authInfo)) { - sendOptions(); + (this->*handler)(); return false; } if (parser.status() == "302" || parser.status() == "301") { @@ -197,7 +203,7 @@ bool RtspPlayer::handleResponse(const string &cmd, const Parser &parser) { } void RtspPlayer::handleResDESCRIBE(const Parser &parser) { - if (!handleResponse("DESCRIBE", parser)) { + if (!handleResponse("DESCRIBE", parser, &RtspPlayer::sendDescribe)) { return; } _content_base = parser["Content-Base"]; @@ -212,6 +218,20 @@ void RtspPlayer::handleResDESCRIBE(const Parser &parser) { // Parse SDP SdpParser sdpParser(parser.content()); + // 保存 range 信息(从第一个 track 获取) + auto tracks = sdpParser.getAvailableTrack(); + if (!tracks.empty()) { + auto title_track = sdpParser.getTrack(TrackTitle); + if (title_track && !title_track->_range_type.empty()) { + _range_type = title_track->_range_type; + _range_start_str = title_track->_range_start_str; + _range_end_str = title_track->_range_end_str; + } else if (!tracks.empty() && !tracks[0]->_range_type.empty()) { + _range_type = tracks[0]->_range_type; + _range_start_str = tracks[0]->_range_start_str; + _range_end_str = tracks[0]->_range_end_str; + } + } _control_url = sdpParser.getControlUrl(_content_base); string sdp; @@ -265,7 +285,7 @@ void RtspPlayer::sendSetup(unsigned int track_idx) { case Rtsp::RTP_TCP: { sendRtspRequest( "SETUP", control_url, - { "Transport", StrPrinter << "RTP/AVP/TCP;unicast;interleaved=" << track->_type * 2 << "-" << track->_type * 2 + 1 << ";mode=play" }); + { "Transport", StrPrinter << "RTP/AVP/TCP;unicast;interleaved=" << track_idx * 2 << "-" << track_idx * 2 + 1 << ";mode=play" }); } break; case Rtsp::RTP_MULTICAST: { sendRtspRequest("SETUP", control_url, { "Transport", "RTP/AVP;multicast;mode=play" }); @@ -411,12 +431,11 @@ void RtspPlayer::handleResSETUP(const Parser &parser, unsigned int track_idx) { // All SETUP commands have been sent // 发送play命令 [AUTO-TRANSLATED:47a826d1] // Send PLAY command - if (_speed==0.0f) { + if (_speed == 0.0f) { sendPause(type_play, 0); } else { sendPause(type_speed, 0); } - } void RtspPlayer::sendDescribe() { @@ -428,7 +447,7 @@ void RtspPlayer::sendDescribe() { void RtspPlayer::sendOptions() { _on_response = [this](const Parser &parser) { - if (!handleResponse("OPTIONS", parser)) { + if (!handleResponse("OPTIONS", parser, &RtspPlayer::sendOptions)) { return; } // 获取服务器支持的命令 [AUTO-TRANSLATED:8a6a12f1] @@ -447,7 +466,11 @@ void RtspPlayer::sendOptions() { } void RtspPlayer::sendKeepAlive() { - _on_response = [](const Parser &parser) {}; + if (_play_check_timer) + { + WarnL << "receive RTP packet before handleResPAUSE"; + } + _on_keepalive_reponse = [](const Parser &parser) {}; if (_supported_cmd.find("GET_PARAMETER") != _supported_cmd.end()) { // 支持GET_PARAMETER,用此命令保活 [AUTO-TRANSLATED:b45cd737] // Support GET_PARAMETER, use this command to keep alive @@ -465,15 +488,53 @@ void RtspPlayer::sendPause(int type, uint32_t seekMS) { // Start or pause RTSP switch (type) { case type_pause: sendRtspRequest("PAUSE", _control_url, {}); break; - case type_play: - // sendRtspRequest("PLAY", _content_base); - // break; - case type_seek: - sendRtspRequest("PLAY", _control_url, { "Range", StrPrinter << "npt=" << setiosflags(ios::fixed) << setprecision(2) << seekMS / 1000.0 << "-" }); - break; - case type_speed: - speed(_speed); - break; + case type_play: sendRtspRequest("PLAY", _content_base); break; + case type_seek: { + std::string range_header; + if (_range_type == "clock" && !_range_start_str.empty()) { + // clock 格式:需要计算新的时间 + // 解析起始时间:20251123T000000Z + struct tm tm_start; + const char *start_str = _range_start_str.c_str(); + if (strptime(start_str, "%Y%m%dT%H%M%SZ", &tm_start) != nullptr) { + // 转换为 time_t,加上 seekMS 毫秒 +#if defined(_WIN32) + time_t start_time = _mkgmtime(&tm_start); +#else + time_t start_time = timegm(&tm_start); +#endif + start_time += seekMS / 1000; // 加上秒数 + + // 格式化新的时间 + struct tm tm_new; +#if defined(_WIN32) + auto gmtime_ret = gmtime_s(&tm_new, &start_time); + if (gmtime_ret == 0) +#else + auto gmtime_ret = gmtime_r(&start_time, &tm_new); + if (gmtime_ret != nullptr) +#endif + { + char new_time[32]; + strftime(new_time, sizeof(new_time), "%Y%m%dT%H%M%SZ", &tm_new); + + // 构建 Range 头 + range_header = StrPrinter << "clock=" << new_time << "-" << _range_end_str; + } else { + // 解析失败,回退到 npt 格式 + range_header = StrPrinter << "npt=" << setiosflags(ios::fixed) << setprecision(2) << seekMS / 1000.0 << "-"; + } + } else { + // 解析失败,回退到 npt 格式 + range_header = StrPrinter << "npt=" << setiosflags(ios::fixed) << setprecision(2) << seekMS / 1000.0 << "-"; + } + } else { + // npt 格式或其他格式 + range_header = StrPrinter << "npt=" << setiosflags(ios::fixed) << setprecision(2) << seekMS / 1000.0 << "-"; + } + sendRtspRequest("PLAY", _control_url, { "Range", range_header }); + } break; + case type_speed: speed(_speed); break; default: WarnL << "unknown type : " << type; _on_response = nullptr; @@ -489,6 +550,10 @@ void RtspPlayer::speed(float speed) { sendRtspRequest("PLAY", _control_url, { "Scale", StrPrinter << speed }); } +void RtspPlayer::seekTo(uint32_t pos) { + seekToMilliSecond(pos * 1000); +} + void RtspPlayer::handleResPAUSE(const Parser &parser, int type) { if (parser.status() != "200") { switch (type) { @@ -537,6 +602,10 @@ void RtspPlayer::onWholeRtspPacket(Parser &parser) { try { decltype(_on_response) func; _on_response.swap(func); + if (!func) + { + _on_keepalive_reponse.swap(func); + } if (func) { func(parser); } @@ -552,7 +621,9 @@ void RtspPlayer::onRtpPacket(const char *data, size_t len) { int trackIdx = -1; uint8_t interleaved = data[1]; if (interleaved % 2 == 0) { - trackIdx = getTrackIndexByInterleaved(interleaved); + CHECK(len > RtpPacket::kRtpHeaderSize + RtpPacket::kRtpTcpHeaderSize); + RtpHeader *header = (RtpHeader *)(data + RtpPacket::kRtpTcpHeaderSize); + trackIdx = getTrackIndexByPT(header->pt); if (trackIdx == -1) { return; } @@ -585,6 +656,9 @@ void RtspPlayer::onRtcpPacket(int track_idx, SdpTrack::Ptr &track, uint8_t *data void RtspPlayer::onRtpSorted(RtpPacket::Ptr rtppt, int trackidx) { _stamp[trackidx] = rtppt->getStampMS(); + if (!_first_stamp[trackidx]) { + _first_stamp[trackidx] = _stamp[trackidx]; + } _rtp_recv_ticker.resetTime(); onRecvRTP(std::move(rtppt), _sdp_track[trackidx]); } @@ -612,7 +686,7 @@ float RtspPlayer::getPacketLossRate(TrackType type) const { } uint32_t RtspPlayer::getProgressMilliSecond() const { - return MAX(_stamp[0], _stamp[1]); + return MAX(_stamp[0] - _first_stamp[0], _stamp[1] - _first_stamp[1]); } void RtspPlayer::seekToMilliSecond(uint32_t ms) { @@ -630,7 +704,6 @@ void RtspPlayer::sendRtspRequest(const string &cmd, const string &url, const std key = val; } } - sendRtspRequest(cmd, url, header_map); } @@ -689,9 +762,18 @@ void RtspPlayer::sendRtspRequest(const string &cmd, const string &url, const Str printer << cmd << " " << url << " RTSP/1.0\r\n"; TraceL << cmd << " "<< url; + + if (cmd == "PLAY") { + // play命令时支持覆盖更新rtsp头,用于onvif点播等场景 + for (auto &pr : _custom_header) { + header[pr.first] = pr.second; + } + } + for (auto &pr : header) { printer << pr.first << ": " << pr.second << "\r\n"; } + printer << "\r\n"; SockSender::send(std::move(printer)); } @@ -798,12 +880,25 @@ void RtspPlayer::onPlayResult_l(const SockException &ex, bool handshake_done) { }; // 创建rtp数据接收超时检测定时器 [AUTO-TRANSLATED:edbffc19] // Create RTP data receive timeout detection timer - _rtp_check_timer = std::make_shared(timeoutMS / 2000.0f, lam, getPoller()); + _rtp_check_timer = std::make_shared(timeoutMS / 2000.0f, std::move(lam), getPoller()); } else { sendTeardown(); } } +int RtspPlayer::getTrackIndexByPT(int pt) const { + for (size_t i = 0; i < _sdp_track.size(); ++i) { + if (_sdp_track[i]->_pt == pt) { + return i; + } + } + if (_sdp_track.size() == 1) { + return 0; + } + WarnL << "no such track with pt:" << pt; + return -1; +} + int RtspPlayer::getTrackIndexByInterleaved(int interleaved) const { for (size_t i = 0; i < _sdp_track.size(); ++i) { if (_sdp_track[i]->_interleaved == interleaved) { @@ -829,6 +924,36 @@ int RtspPlayer::getTrackIndexByTrackType(TrackType track_type) const { throw SockException(Err_other, StrPrinter << "no such track with type:" << getTrackString(track_type)); } +size_t RtspPlayer::getRecvSpeed() { + size_t ret = TcpClient::getRecvSpeed(); + for (auto &rtp : _rtp_sock) { + if (rtp) { + ret += rtp->getRecvSpeed(); + } + } + for (auto &rtcp : _rtcp_sock) { + if (rtcp) { + ret += rtcp->getRecvSpeed(); + } + } + return ret; +} + +size_t RtspPlayer::getRecvTotalBytes() { + size_t ret = TcpClient::getRecvTotalBytes(); + for (auto &rtp : _rtp_sock) { + if (rtp) { + ret += rtp->getRecvTotalBytes(); + } + } + for (auto &rtcp : _rtcp_sock) { + if (rtcp) { + ret += rtcp->getRecvTotalBytes(); + } + } + return ret; +} + /////////////////////////////////////////////////// // RtspPlayerImp float RtspPlayerImp::getDuration() const { diff --git a/src/Rtsp/RtspPlayer.h b/src/Rtsp/RtspPlayer.h index 58e2cfa5..6103d8f9 100644 --- a/src/Rtsp/RtspPlayer.h +++ b/src/Rtsp/RtspPlayer.h @@ -36,9 +36,13 @@ public: void play(const std::string &strUrl) override; void pause(bool pause) override; void speed(float speed) override; + void seekTo(uint32_t pos) override; // 新增 void teardown() override; float getPacketLossRate(TrackType type) const override; + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; + protected: // 派生类回调函数 [AUTO-TRANSLATED:61e20903] // Derived class callback function @@ -117,6 +121,7 @@ protected: private: void onPlayResult_l(const toolkit::SockException &ex , bool handshake_done); + int getTrackIndexByPT(int pt) const; int getTrackIndexByInterleaved(int interleaved) const; int getTrackIndexByTrackType(TrackType track_type) const; @@ -124,7 +129,8 @@ private: void handleResDESCRIBE(const Parser &parser); bool handleAuthenticationFailure(const std::string &wwwAuthenticateParamsStr); void handleResPAUSE(const Parser &parser, int type); - bool handleResponse(const std::string &cmd, const Parser &parser); + using send_method_handler = void (RtspPlayer::*)(void); + bool handleResponse(const std::string &cmd, const Parser &parser, send_method_handler handler); void sendOptions(); void sendSetup(unsigned int track_idx); @@ -154,9 +160,11 @@ private: std::string _play_url; // rtsp开始倍速 [AUTO-TRANSLATED:9ab84508] // Rtsp start speed - float _speed= 0.0f; + float _speed = 0.0f; std::vector _sdp_track; std::function _on_response; + std::function _on_keepalive_reponse; + protected: // RTP端口,trackid idx 为数组下标 [AUTO-TRANSLATED:77c186bb] // RTP port, trackid idx is the array subscript toolkit::Socket::Ptr _rtp_sock[2]; @@ -164,6 +172,7 @@ private: // RTCP port, trackid idx is the array subscript toolkit::Socket::Ptr _rtcp_sock[2]; +private: // rtsp鉴权相关 [AUTO-TRANSLATED:947dc6a3] // Rtsp authentication related std::string _md5_nonce; @@ -173,11 +182,21 @@ private: uint32_t _cseq_send = 1; std::string _content_base; std::string _control_url; + + std::string _range_type; // 新增:保存 range 类型 + std::string _range_start_str; // 新增:保存 clock 格式的起始时间 + std::string _range_end_str; // 新增:保存 clock 格式的结束时间 + +protected: Rtsp::eRtpType _rtp_type = Rtsp::RTP_TCP; +private: + // 起始时间戳 + uint64_t _first_stamp[2] = {0, 0}; + // 当前rtp时间戳 [AUTO-TRANSLATED:410f2691] // Current rtp timestamp - uint32_t _stamp[2] = {0, 0}; + uint64_t _stamp[2] = {0, 0}; // 超时功能实现 [AUTO-TRANSLATED:1d603b3a] // Timeout function implementation @@ -194,6 +213,8 @@ private: // 统计rtp并发送rtcp [AUTO-TRANSLATED:0ac2b665] // Statistics rtp and send rtcp std::vector _rtcp_context; + // 用户自定义rtsp头 + StrCaseMap _custom_header; }; } /* namespace mediakit */ diff --git a/src/Rtsp/RtspPlayerImp.h b/src/Rtsp/RtspPlayerImp.h index 0bb7b474..4f65e582 100644 --- a/src/Rtsp/RtspPlayerImp.h +++ b/src/Rtsp/RtspPlayerImp.h @@ -51,7 +51,13 @@ public: } void seekTo(uint32_t seekPos) override { - uint32_t pos = MAX(float(0), MIN(seekPos, getDuration())) * 1000; + uint32_t pos = seekPos * 1000; + // 如果是点播流(有时长),限制在有效范围内 + // If it's a VOD stream (has duration), limit to valid range + float duration = getDuration(); + if (duration > 0) { + pos = MAX(float(0), MIN(seekPos, getDuration())) * 1000; + } seekToMilliSecond(pos); } diff --git a/src/Rtsp/RtspPusher.cpp b/src/Rtsp/RtspPusher.cpp index 7b644363..ba75fb26 100644 --- a/src/Rtsp/RtspPusher.cpp +++ b/src/Rtsp/RtspPusher.cpp @@ -277,8 +277,8 @@ void RtspPusher::sendSetup(unsigned int track_idx) { switch (_rtp_type) { case Rtsp::RTP_TCP: { sendRtspRequest("SETUP", control_url, {"Transport", - StrPrinter << "RTP/AVP/TCP;unicast;interleaved=" << track->_type * 2 - << "-" << track->_type * 2 + 1 << ";mode=record"}); + StrPrinter << "RTP/AVP/TCP;unicast;interleaved=" << track_idx * 2 + << "-" << track_idx * 2 + 1 << ";mode=record"}); } break; case Rtsp::RTP_UDP: { @@ -595,5 +595,34 @@ void RtspPusher::sendRtspRequest(const string &cmd, const string &url,const StrC SockSender::send(std::move(printer)); } +size_t RtspPusher::getSendSpeed() { + size_t ret = TcpClient::getSendSpeed(); + for (auto &rtp : _rtp_sock) { + if (rtp) { + ret += rtp->getSendSpeed(); + } + } + for (auto &rtcp : _rtcp_sock) { + if (rtcp) { + ret += rtcp->getSendSpeed(); + } + } + return ret; +} + +size_t RtspPusher::getSendTotalBytes() { + size_t ret = TcpClient::getSendTotalBytes(); + for (auto &rtp : _rtp_sock) { + if (rtp) { + ret += rtp->getSendTotalBytes(); + } + } + for (auto &rtcp : _rtcp_sock) { + if (rtcp) { + ret += rtcp->getSendTotalBytes(); + } + } + return ret; +} } /* namespace mediakit */ diff --git a/src/Rtsp/RtspPusher.h b/src/Rtsp/RtspPusher.h index eb929f34..ca2c962e 100644 --- a/src/Rtsp/RtspPusher.h +++ b/src/Rtsp/RtspPusher.h @@ -30,6 +30,8 @@ public: ~RtspPusher() override; void publish(const std::string &url) override; void teardown() override; + size_t getSendSpeed() override; + size_t getSendTotalBytes() override; protected: //for Tcpclient override diff --git a/src/Rtsp/RtspSession.cpp b/src/Rtsp/RtspSession.cpp index 6fa88b7e..d0367970 100644 --- a/src/Rtsp/RtspSession.cpp +++ b/src/Rtsp/RtspSession.cpp @@ -134,6 +134,7 @@ void RtspSession::onWholeRtspPacket(Parser &parser) { _content_base = rtsp._url; _media_info.parse(parser.fullUrl()); _media_info.schema = RTSP_SCHEMA; + _media_info.protocol = overSsl() ? "rtsps" : "rtsp"; } using rtsp_request_handler = void (RtspSession::*)(const Parser &parser); @@ -166,7 +167,9 @@ void RtspSession::onWholeRtspPacket(Parser &parser) { void RtspSession::onRtpPacket(const char *data, size_t len) { uint8_t interleaved = data[1]; if (interleaved % 2 == 0) { - auto track_idx = getTrackIndexByInterleaved(interleaved); + CHECK(len > RtpPacket::kRtpHeaderSize + RtpPacket::kRtpTcpHeaderSize); + RtpHeader *header = (RtpHeader *)(data + RtpPacket::kRtpTcpHeaderSize); + auto track_idx = getTrackIndexByPT(header->pt); handleOneRtp(track_idx, _sdp_track[track_idx]->_type, _sdp_track[track_idx]->_samplerate, (uint8_t *) data + RtpPacket::kRtpTcpHeaderSize, len - RtpPacket::kRtpTcpHeaderSize); } else { auto track_idx = getTrackIndexByInterleaved(interleaved - 1); @@ -206,6 +209,7 @@ void RtspSession::handleReq_ANNOUNCE(const Parser &parser) { //去除.sdp后缀,防止EasyDarwin推流器强制添加.sdp后缀 full_url = full_url.substr(0, full_url.length() - 4); _media_info.parse(full_url); + _media_info.protocol = overSsl() ? "rtsps" : "rtsp"; } if (_media_info.app.empty() || _media_info.stream.empty()) { @@ -433,7 +437,7 @@ void RtspSession::onAuthSuccess() { strong_self->_play_src = rtsp_src; for(auto &track : strong_self->_sdp_track){ track->_ssrc = rtsp_src->getSsrc(track->_type); - track->_seq = rtsp_src->getSeqence(track->_type); + track->_seq = rtsp_src->getSequence(track->_type); track->_time_stamp = rtsp_src->getTimeStamp(track->_type); } @@ -828,7 +832,7 @@ void RtspSession::handleReq_Play(const Parser &parser) { } inited_tracks.emplace_back(track->_type); track->_ssrc = play_src->getSsrc(track->_type); - track->_seq = play_src->getSeqence(track->_type); + track->_seq = play_src->getSequence(track->_type); track->_time_stamp = play_src->getTimeStamp(track->_type); rtp_info << "url=" << track->getControlUrl(_content_base) << ";" @@ -859,7 +863,7 @@ void RtspSession::handleReq_Play(const Parser &parser) { _play_reader = play_src->getRing()->attach(getPoller(), use_gop); _play_reader->setGetInfoCB([weak_self]() { Any ret; - ret.set(static_pointer_cast(weak_self.lock())); + ret.set(static_pointer_cast(weak_self.lock())); return ret; }); _play_reader->setDetachCB([weak_self]() { @@ -1122,6 +1126,18 @@ bool RtspSession::sendRtspResponse(const string &res_code, const std::initialize return sendRtspResponse(res_code,header_map,sdp,protocol); } +int RtspSession::getTrackIndexByPT(int pt) const { + for (size_t i = 0; i < _sdp_track.size(); ++i) { + if (pt == _sdp_track[i]->_pt) { + return i; + } + } + if (_sdp_track.size() == 1) { + return 0; + } + throw SockException(Err_shutdown, StrPrinter << "no such track with pt:" << pt); +} + int RtspSession::getTrackIndexByTrackType(TrackType type) { for (size_t i = 0; i < _sdp_track.size(); ++i) { if (type == _sdp_track[i]->_type) { @@ -1159,9 +1175,7 @@ int RtspSession::getTrackIndexByInterleaved(int interleaved) { } bool RtspSession::close(MediaSource &sender) { - //此回调在其他线程触发 - string err = StrPrinter << "close media: " << sender.getUrl(); - safeShutdown(SockException(Err_shutdown,err)); + shutdown(SockException(Err_shutdown,"close media: " + sender.getUrl())); return true; } diff --git a/src/Rtsp/RtspSession.h b/src/Rtsp/RtspSession.h index f985fd3b..a9c80060 100644 --- a/src/Rtsp/RtspSession.h +++ b/src/Rtsp/RtspSession.h @@ -153,6 +153,7 @@ private: void send_NotAcceptable(); // 获取track下标 [AUTO-TRANSLATED:36d0b2c2] // Get the track index + int getTrackIndexByPT(int pt) const; int getTrackIndexByTrackType(TrackType type); int getTrackIndexByControlUrl(const std::string &control_url); int getTrackIndexByInterleaved(int interleaved); diff --git a/src/Shell/ShellCMD.h b/src/Shell/ShellCMD.h index 3df4ee50..4b889dc2 100644 --- a/src/Shell/ShellCMD.h +++ b/src/Shell/ShellCMD.h @@ -36,9 +36,9 @@ public: if (!media) { break; } - if (!media->close(true)) { - break; - } + media->getOwnerPoller()->async([media]() { + media->close(true); + }); (*stream) << "\t踢出成功:" << media->getUrl() << "\r\n"; return; } while (0); diff --git a/srt/Ack.cpp b/srt/Ack.cpp index 44249cb4..a7c799b1 100644 --- a/srt/Ack.cpp +++ b/srt/Ack.cpp @@ -80,6 +80,6 @@ std::string ACKPacket::dump() { << " rtt_variance=" << rtt_variance << " pkt_recv_rate=" << pkt_recv_rate << " available_buf_size=" << available_buf_size << " estimated_link_capacity=" << estimated_link_capacity << " recv_rate=" << recv_rate; - return std::move(printer); + return printer; } } // namespace SRT \ No newline at end of file diff --git a/srt/CMakeLists.txt b/srt/CMakeLists.txt index fcbdbd8a..417c6535 100644 --- a/srt/CMakeLists.txt +++ b/srt/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/srt/Crypto.cpp b/srt/Crypto.cpp index 40c29f8f..798809a4 100644 --- a/srt/Crypto.cpp +++ b/srt/Crypto.cpp @@ -49,11 +49,11 @@ inline const EVP_CIPHER* aes_key_len_mapping_ctr_cipher(int key_len) { static bool aes_wrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len) { #if defined(ENABLE_OPENSSL) - EVP_CIPHER_CTX* ctx = NULL; + EVP_CIPHER_CTX* ctx = NULL; *outLen = 0; - do { + do { if (!(ctx = EVP_CIPHER_CTX_new())) { WarnL << "EVP_CIPHER_CTX_new fail"; break; @@ -62,29 +62,29 @@ static bool aes_wrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, u if (1 != EVP_EncryptInit_ex(ctx, aes_key_len_mapping_wrap_cipher(key_len), NULL, key, NULL)) { WarnL << "EVP_EncryptInit_ex fail"; - break; - } + break; + } - int len1 = 0; - if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { + int len1 = 0; + if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { WarnL << "EVP_EncryptUpdate fail"; - break; - } + break; + } - int len2 = 0; - if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { + int len2 = 0; + if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { WarnL << "EVP_EncryptFinal_ex fail"; - break; - } + break; + } - *outLen = len1 + len2; - } while (0); + *outLen = len1 + len2; + } while (0); - if (ctx != NULL) { - EVP_CIPHER_CTX_free(ctx); - } + if (ctx != NULL) { + EVP_CIPHER_CTX_free(ctx); + } - return *outLen != 0; + return *outLen != 0; #else return false; #endif @@ -103,11 +103,11 @@ static bool aes_wrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, u static bool aes_unwrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len) { #if defined(ENABLE_OPENSSL) - EVP_CIPHER_CTX* ctx = NULL; + EVP_CIPHER_CTX* ctx = NULL; *outLen = 0; - do { + do { if (!(ctx = EVP_CIPHER_CTX_new())) { WarnL << "EVP_CIPHER_CTX_new fail"; @@ -117,8 +117,8 @@ static bool aes_unwrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_wrap_cipher(key_len), NULL, key, NULL)) { WarnL << "EVP_DecryptInit_ex fail"; - break; - } + break; + } //设置pkcs7padding if (1 != EVP_CIPHER_CTX_set_padding(ctx, 1)) { @@ -126,26 +126,26 @@ static bool aes_unwrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, break; } - int len1 = 0; - if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { + int len1 = 0; + if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { WarnL << "EVP_DecryptUpdate fail"; - break; - } + break; + } - int len2 = 0; - if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { + int len2 = 0; + if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { WarnL << "EVP_DecryptFinal_ex fail"; - break; - } + break; + } - *outLen = len1 + len2; - } while (0); + *outLen = len1 + len2; + } while (0); - if (ctx != NULL) { - EVP_CIPHER_CTX_free(ctx); - } + if (ctx != NULL) { + EVP_CIPHER_CTX_free(ctx); + } - return *outLen != 0; + return *outLen != 0; #else return false; @@ -166,11 +166,11 @@ static bool aes_unwrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, static bool aes_ctr_encrypt(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len, uint8_t* iv) { #if defined(ENABLE_OPENSSL) - EVP_CIPHER_CTX* ctx = NULL; + EVP_CIPHER_CTX* ctx = NULL; *outLen = 0; - do { + do { if (!(ctx = EVP_CIPHER_CTX_new())) { WarnL << "EVP_CIPHER_CTX_new fail"; break; @@ -178,29 +178,29 @@ static bool aes_ctr_encrypt(const uint8_t* in, int in_len, uint8_t* out, int* ou if (1 != EVP_EncryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) { WarnL << "EVP_EncryptInit_ex fail"; - break; - } + break; + } - int len1 = 0; - if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { + int len1 = 0; + if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { WarnL << "EVP_EncryptUpdate fail"; - break; - } + break; + } - int len2 = 0; - if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { + int len2 = 0; + if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { WarnL << "EVP_EncryptFinal_ex fail"; - break; - } + break; + } - *outLen = len1 + len2; - } while (0); + *outLen = len1 + len2; + } while (0); - if (ctx != NULL) { - EVP_CIPHER_CTX_free(ctx); - } + if (ctx != NULL) { + EVP_CIPHER_CTX_free(ctx); + } - return *outLen != 0; + return *outLen != 0; #else return false; #endif @@ -221,42 +221,42 @@ static bool aes_ctr_encrypt(const uint8_t* in, int in_len, uint8_t* out, int* ou static bool aes_ctr_decrypt(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len, uint8_t* iv) { #if defined(ENABLE_OPENSSL) - EVP_CIPHER_CTX* ctx = NULL; + EVP_CIPHER_CTX* ctx = NULL; *outLen = 0; - do { + do { if (!(ctx = EVP_CIPHER_CTX_new())) { WarnL << "EVP_CIPHER_CTX_new fail"; break; } - if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) { + if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) { WarnL << "EVP_DecryptInit_ex fail"; - break; - } + break; + } - int len1 = 0; - if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { + int len1 = 0; + if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) { WarnL << "EVP_DecryptUpdate fail"; - break; - } + break; + } - int len2 = 0; - if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { + int len2 = 0; + if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) { WarnL << "EVP_DecryptFinal_ex fail"; - break; - } + break; + } - *outLen = len1 + len2; - } while (0); + *outLen = len1 + len2; + } while (0); - if (ctx != NULL) { - EVP_CIPHER_CTX_free(ctx); - } + if (ctx != NULL) { + EVP_CIPHER_CTX_free(ctx); + } - return *outLen != 0; + return *outLen != 0; #else return false; diff --git a/srt/HSExt.cpp b/srt/HSExt.cpp index 45be1a12..826678cd 100644 --- a/srt/HSExt.cpp +++ b/srt/HSExt.cpp @@ -34,7 +34,7 @@ std::string HSExtMessage::dump() { _StrPrinter printer; printer << "srt version : " << std::hex << srt_version << " srt flag : " << std::hex << srt_flag << " recv_tsbpd_delay=" << recv_tsbpd_delay << " send_tsbpd_delay = " << send_tsbpd_delay; - return std::move(printer); + return printer; } bool HSExtMessage::storeToData() { @@ -83,8 +83,10 @@ bool HSExtStreamID::loadFromData(uint8_t *buf, size_t len) { ptr += 4; } char zero = 0x00; - if (streamid.back() == zero) { - streamid.erase(streamid.find_first_of(zero), streamid.size()); + if (!streamid.empty()) { + if (streamid.back() == zero) { + streamid.erase(streamid.find_first_of(zero), streamid.size()); + } } return true; } @@ -128,7 +130,7 @@ bool HSExtStreamID::storeToData() { std::string HSExtStreamID::dump() { _StrPrinter printer; printer << " streamid : " << streamid; - return std::move(printer); + return printer; } size_t KeyMaterial::getContentSize() { @@ -261,7 +263,7 @@ std::string KeyMaterial::dump() { << " sLen : " << _slen << " salt : " << std::hex << _salt.data() << " kLen : " << _klen; - return std::move(printer); + return printer; } bool HSExtKeyMaterial::loadFromData(uint8_t *buf, size_t len) { diff --git a/srt/Packet.cpp b/srt/Packet.cpp index edfbe36d..f3f55ca0 100644 --- a/srt/Packet.cpp +++ b/srt/Packet.cpp @@ -279,7 +279,7 @@ std::string HandshakePacket::dump(){ for(size_t i=0;idump()<<"\r\n"; } - return std::move(printer); + return printer; } bool HandshakePacket::loadFromData(uint8_t *buf, size_t len) { if (HEADER_SIZE + HS_CONTENT_MIN_SIZE > len) { @@ -353,7 +353,7 @@ bool HandshakePacket::loadExtMessage(uint8_t *buf, size_t len) { case HSExt::SRT_CMD_SID: ext = std::make_shared(); break; case HSExt::SRT_CMD_KMREQ: case HSExt::SRT_CMD_KMRSP: - ext = std::make_shared(); break; + ext = std::make_shared(); break; default: WarnL << "not support ext " << type; break; } if (ext) { @@ -627,7 +627,7 @@ std::string NAKPacket::dump() { for (auto it : lost_list) { printer << "[ " << it.first << " , " << it.second - 1 << " ]"; } - return std::move(printer); + return printer; } bool MsgDropReqPacket::loadFromData(uint8_t *buf, size_t len) { diff --git a/srt/PacketQueue.cpp b/srt/PacketQueue.cpp index ce1da952..ecf191dc 100644 --- a/srt/PacketQueue.cpp +++ b/srt/PacketQueue.cpp @@ -233,7 +233,7 @@ std::string PacketQueue::dump() { printer << " last:" << _pkt_map.rbegin()->second->packet_seq_number; printer << " latency:" << timeLatency() / 1e3; } - return std::move(printer); + return printer; } //////////////////// PacketRecvQueue ////////////////////////////////// @@ -395,7 +395,7 @@ std::string PacketRecvQueue::dump() { printer << " start:" << _start; printer << " end:" << _end; } - return std::move(printer); + return printer; } bool PacketRecvQueue::drop(uint32_t first, uint32_t last, std::list &out) { uint32_t diff = 0; diff --git a/src/Srt/SrtCaller.cpp b/srt/SrtCaller.cpp similarity index 82% rename from src/Srt/SrtCaller.cpp rename to srt/SrtCaller.cpp index 2a84c8c3..3871e2ed 100644 --- a/src/Srt/SrtCaller.cpp +++ b/srt/SrtCaller.cpp @@ -21,57 +21,42 @@ using namespace SRT; namespace mediakit { +//zlm play format //srt://127.0.0.1:9000?streamid=#!::r=live/test //srt://127.0.0.1:9000?streamid=#!::r=live/test,h=__defaultVhost__ +//zlm push format +//srt://127.0.0.1:9000?streamid=#!::r=live/test,m=publish +//srt://127.0.0.1:9000?streamid=#!::r=live/test,h=__defaultVhost__,m=publish void SrtUrl::parse(const string &strUrl) { - //DebugL << "url: " << strUrl; + //DebugL << "url: " << strUrl; _full_url = strUrl; auto url = strUrl; auto ip = findSubString(url.data(), "://", "?"); splitUrl(ip, _host, _port); - auto _params = findSubString(url.data(), "?" , NULL); + if (!SockUtil::getDomainIP(_host.c_str(), _port, _addr, AF_INET, SOCK_DGRAM, IPPROTO_UDP)) { + throw std::invalid_argument("invalid host: " + _host); + } + + auto _params = findSubString(url.data(), "?" , NULL); auto kv = Parser::parseArgs(_params); auto it = kv.find("streamid"); - if (it != kv.end()) { - auto streamid = it->second; - if (!toolkit::start_with(streamid, "#!::")) { - return; - } - std::string real_streamid = streamid.substr(4); + if (it != kv.end()) { + auto streamid = it->second; + if (!toolkit::start_with(streamid, "#!::")) { + return; + } + _streamid = streamid; + } - auto params = Parser::parseArgs(real_streamid, ",", "="); - - for (auto iit : params) { - if (iit.first == "h") { - _vhost = iit.second; - } else if (iit.first == "r") { - auto tmps = toolkit::split(iit.second, "/"); - if (tmps.size() < 2) { - continue; - } - _app = tmps[0]; - _stream = tmps[1]; - } else { - //nop - } - } - - if (_vhost.empty()) { - _vhost = DEFAULT_VHOST; - } - } - - //TraceL << "ip: " << ip; - //TraceL << "_host: " << _host; - //TraceL << "_port: " << _port; - //TraceL << "_params: " << _params; - //TraceL << "_vhost: " << _vhost; - //TraceL << "_app: " << _app; - //TraceL << "_stream: " << _stream; - return; + //TraceL << "ip: " << ip; + //TraceL << "_host: " << _host; + //TraceL << "_port: " << _port; + //TraceL << "_params: " << _params; + //TraceL << "_streamid: " << _streamid; + return; } @@ -79,10 +64,10 @@ void SrtUrl::parse(const string &strUrl) { SrtCaller::SrtCaller(const toolkit::EventPoller::Ptr &poller) { _poller = poller ? std::move(poller) : EventPollerPool::Instance().getPoller(); _start_timestamp = SteadyClock::now(); - _socket_id = generateSocketId(); + _socket_id = generateSocketId(); - /* _init_seq_number = generateInitSeq(); */ - _init_seq_number = 0; + /* _init_seq_number = generateInitSeq(); */ + _init_seq_number = 0; _last_pkt_seq = _init_seq_number - 1; _pkt_recv_rate_context = std::make_shared(_start_timestamp); @@ -93,16 +78,15 @@ SrtCaller::SrtCaller(const toolkit::EventPoller::Ptr &poller) { } SrtCaller::~SrtCaller(void) { - DebugL; + DebugL; } void SrtCaller::onConnect() { - //DebugL; + //DebugL; - auto peer_addr = SockUtil::make_sockaddr(_url._host.c_str(), (_url._port)); - _socket = Socket::createSocket(_poller, false); - _socket->bindUdpSock(0); - _socket->bindPeerAddr((struct sockaddr *)&peer_addr, 0, true); + _socket = Socket::createSocket(_poller, false); + _socket->bindUdpSock(0, _url._addr.ss_family == AF_INET ? "0.0.0.0" : "::"); + _socket->bindPeerAddr((struct sockaddr *)&_url._addr, 0, true); weak_ptr weak_self = shared_from_this(); _socket->setOnRead([weak_self](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) mutable { @@ -110,10 +94,10 @@ void SrtCaller::onConnect() { if (!strong_self) { return; } - strong_self->inputSockData((uint8_t*)buf->data(), buf->size(), addr); + strong_self->inputSockData((uint8_t*)buf->data(), buf->size(), addr); }); - doHandshake(); + doHandshake(); } void SrtCaller::onResult(const SockException &ex) { @@ -138,7 +122,7 @@ void SrtCaller::onResult(const SockException &ex) { void SrtCaller::onHandShakeFinished() { DebugL; - _is_handleshake_finished = true; + _is_handleshake_finished = true; if (_handleshake_timer) { _handleshake_timer.reset(); } @@ -165,7 +149,7 @@ void SrtCaller::onHandShakeFinished() { void SrtCaller::onSRTData(DataPacket::Ptr pkt) { InfoL; if (!isPlayer()) { - WarnL << "this is not a player data ignore"; + WarnL << "this is not a player data ignore"; return; } } @@ -215,7 +199,7 @@ void SrtCaller::onSendTSData(const Buffer::Ptr &buffer, bool flush) { } void SrtCaller::inputSockData(uint8_t *buf, int len, struct sockaddr *addr) { - //TraceL << hexdump((void*)buf, len); + //TraceL << hexdump((void*)buf, len); using srt_control_handler = void (SrtCaller::*)(uint8_t * buf, int len, struct sockaddr *addr); static std::unordered_map s_control_functions; @@ -237,8 +221,8 @@ void SrtCaller::inputSockData(uint8_t *buf, int len, struct sockaddr *addr) { // 处理srt数据 if (DataPacket::isDataPacket(buf, len)) { - uint32_t socketId = DataPacket::getSocketID(buf, len); - if (isPlayer()) { + if (_is_handleshake_finished && isPlayer()) { + uint32_t socketId = DataPacket::getSocketID(buf, len); if (socketId == _socket_id) { _pkt_recv_rate_context->inputPacket(_now, len + UDP_HDR_SIZE); handleDataPacket(buf, len, addr); @@ -277,16 +261,16 @@ void SrtCaller::doHandshake() { _crypto = std::make_shared(getPassphrase()); } - sendHandshakeInduction(); + sendHandshakeInduction(); return; } void SrtCaller::sendHandshakeInduction() { - DebugL; + DebugL; _induction_ts = SteadyClock::now(); - SRT::HandshakePacket::Ptr req = std::make_shared(); - req->timestamp = DurationCountMicroseconds(_induction_ts - _start_timestamp); + SRT::HandshakePacket::Ptr req = std::make_shared(); + req->timestamp = DurationCountMicroseconds(_induction_ts - _start_timestamp); req->dst_socket_id = 0; req->version = 4; @@ -299,11 +283,10 @@ void SrtCaller::sendHandshakeInduction() { req->srt_socket_id = _socket_id; req->syn_cookie = 0; - auto dataSenderAddr = SockUtil::make_sockaddr(_url._host.c_str(), _url._port); - req->assignPeerIPBE(&dataSenderAddr); + req->assignPeerIPBE(&_url._addr); req->storeToData(); - _handleshake_req = req; - sendControlPacket(req, true); + _handleshake_req = req; + sendControlPacket(req, true); std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); _handleshake_timer = std::make_shared(0.2, [weak_self]()->bool{ @@ -323,10 +306,10 @@ void SrtCaller::sendHandshakeInduction() { } void SrtCaller::sendHandshakeConclusion() { - DebugL; + DebugL; - SRT::HandshakePacket::Ptr req = std::make_shared(); - req->timestamp = DurationCountMicroseconds(_now - _start_timestamp); + SRT::HandshakePacket::Ptr req = std::make_shared(); + req->timestamp = DurationCountMicroseconds(_now - _start_timestamp); req->dst_socket_id = 0; req->version = 5; @@ -345,13 +328,12 @@ void SrtCaller::sendHandshakeConclusion() { req->srt_socket_id = _socket_id; req->syn_cookie = _sync_cookie; - auto addr = SockUtil::make_sockaddr(_url._host.c_str(), _url._port); - req->assignPeerIPBE(&addr); + req->assignPeerIPBE(&_url._addr); - HSExtMessage::Ptr ext = std::make_shared(); - ext->extension_type = HSExt::SRT_CMD_HSREQ; - ext->srt_version = srtVersion(1, 5, 0); - ext->srt_flag = 0xbf; + HSExtMessage::Ptr ext = std::make_shared(); + ext->extension_type = HSExt::SRT_CMD_HSREQ; + ext->srt_version = srtVersion(1, 5, 0); + ext->srt_flag = 0xbf; // if set latency, use set value _delay = getLatency(); @@ -364,13 +346,13 @@ void SrtCaller::sendHandshakeConclusion() { } } - ext->recv_tsbpd_delay = _delay; - ext->send_tsbpd_delay = _delay; - req->ext_list.push_back(std::move(ext)); + ext->recv_tsbpd_delay = _delay; + ext->send_tsbpd_delay = _delay; + req->ext_list.push_back(std::move(ext)); - HSExtStreamID::Ptr extStreamId = std::make_shared(); - extStreamId->streamid = generateStreamId(); - req->ext_list.push_back(std::move(extStreamId)); + HSExtStreamID::Ptr extStreamId = std::make_shared(); + extStreamId->streamid = generateStreamId(); + req->ext_list.push_back(std::move(extStreamId)); if (_crypto) { HSExtKeyMaterial::Ptr keyMaterial = _crypto->generateKeyMaterialExt(HSExt::SRT_CMD_KMREQ); @@ -378,8 +360,8 @@ void SrtCaller::sendHandshakeConclusion() { } req->storeToData(); - _handleshake_req = req; - sendControlPacket(req); + _handleshake_req = req; + sendControlPacket(req); return; } @@ -491,7 +473,7 @@ void SrtCaller::sendMsgDropReq(uint32_t first, uint32_t last) { void SrtCaller::sendKeepLivePacket() { auto now = SteadyClock::now(); - SRT::KeepLivePacket::Ptr req = std::make_shared(); + SRT::KeepLivePacket::Ptr req = std::make_shared(); req->timestamp = SRT::DurationCountMicroseconds(now - _start_timestamp); req->dst_socket_id = _peer_socket_id; req->storeToData(); @@ -510,7 +492,7 @@ void SrtCaller::sendShutDown() { } void SrtCaller::tryAnnounceKeyMaterial() { - //TraceL; + //TraceL; if (!_crypto) { return; @@ -546,9 +528,9 @@ void SrtCaller::tryAnnounceKeyMaterial() { } void SrtCaller::sendControlPacket(SRT::ControlPacket::Ptr pkt, bool flush) { - //TraceL; + //TraceL; sendPacket(pkt, flush); - return; + return; } void SrtCaller::sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, bool flush) { @@ -571,22 +553,22 @@ void SrtCaller::sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, boo pkt->storeToData((uint8_t *)data, size); sendPacket(pkt, flush); _send_buf->inputPacket(pkt); - return; + return; } void SrtCaller::sendPacket(Buffer::Ptr pkt, bool flush) { - //TraceL << pkt->size(); + //TraceL << pkt->size(); auto tmp = _packet_pool.obtain2(); tmp->assign(pkt->data(), pkt->size()); - _socket->send(std::move(tmp), nullptr, 0, flush); + _socket->send(std::move(tmp), nullptr, 0, flush); _send_ticker.resetTime(); - return; + return; } void SrtCaller::handleHandshake(uint8_t *buf, int len, struct sockaddr *addr) { - //DebugL; - SRT::HandshakePacket pkt; + //DebugL; + SRT::HandshakePacket pkt; if(!pkt.loadFromData(buf, len)){ WarnL<< "is not vaild HandshakePacket"; return; @@ -610,96 +592,96 @@ void SrtCaller::handleHandshake(uint8_t *buf, int len, struct sockaddr *addr) { } void SrtCaller::handleHandshakeInduction(SRT::HandshakePacket &pkt, struct sockaddr *addr) { - DebugL; + DebugL; - if (!_handleshake_req) { - WarnL << "must Induction Phase for handleshake"; - return; - } + if (!_handleshake_req) { + WarnL << "must Induction Phase for handleshake"; + return; + } - if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_CONCLUSION) { - WarnL << "should be Conclusion Phase for handleshake "; - return; - } else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_INDUCTION) { - WarnL <<"not reach this"; - return; - } + if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_CONCLUSION) { + WarnL << "should be Conclusion Phase for handleshake "; + return; + } else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_INDUCTION) { + WarnL <<"not reach this"; + return; + } - // Induction Phase + // Induction Phase if (pkt.version != 5) { - WarnL << "not support handleshake version: " << pkt.version; - return; - } + WarnL << "not support handleshake version: " << pkt.version; + return; + } - if (pkt.extension_field != 0x4A17) { - WarnL << "not match SRT MAGIC"; - return; - } + if (pkt.extension_field != 0x4A17) { + WarnL << "not match SRT MAGIC"; + return; + } - if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) { - WarnL << "not match _socket_id"; - return; - } + if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) { + WarnL << "not match _socket_id"; + return; + } // TODO: encryption_field - _sync_cookie = pkt.syn_cookie; + _sync_cookie = pkt.syn_cookie; _mtu = std::min(pkt.mtu, _mtu); _max_flow_window_size = std::min(pkt.max_flow_window_size, _max_flow_window_size); - sendHandshakeConclusion(); + sendHandshakeConclusion(); return; } void SrtCaller::handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sockaddr *addr) { - DebugL; + DebugL; - if (!_handleshake_req) { - WarnL << "must Conclusion Phase for handleshake "; - return; - } + if (!_handleshake_req) { + WarnL << "must Conclusion Phase for handleshake "; + return; + } - if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) { - WarnL << "should be Conclusion Phase for handleshake "; - return; - } else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_CONCLUSION) { - WarnL <<"not reach this"; - return; - } + if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) { + WarnL << "should be Conclusion Phase for handleshake "; + return; + } else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_CONCLUSION) { + WarnL <<"not reach this"; + return; + } - // Conclusion Phase + // Conclusion Phase if (pkt.version != 5) { - WarnL << "not support handleshake version: " << pkt.version; - return; - } + WarnL << "not support handleshake version: " << pkt.version; + return; + } - if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) { - WarnL << "not match _socket_id"; - return; - } + if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) { + WarnL << "not match _socket_id"; + return; + } // TODO: encryption_field - _peer_socket_id = pkt.srt_socket_id; + _peer_socket_id = pkt.srt_socket_id; - HSExtMessage::Ptr resp; + HSExtMessage::Ptr resp; HSExtKeyMaterial::Ptr keyMaterial; - for (auto& ext : pkt.ext_list) { - if (!resp) { - resp = std::dynamic_pointer_cast(ext); - } + for (auto& ext : pkt.ext_list) { + if (!resp) { + resp = std::dynamic_pointer_cast(ext); + } if (!keyMaterial) { keyMaterial = std::dynamic_pointer_cast(ext); } - } + } - if (resp) { + if (resp) { _delay = std::max(_delay, resp->recv_tsbpd_delay); - //DebugL << "flag " << resp->srt_flag; - //DebugL << "recv_tsbpd_delay " << resp->recv_tsbpd_delay; - //DebugL << "send_tsbpd_delay " << resp->send_tsbpd_delay; - } + //DebugL << "flag " << resp->srt_flag; + //DebugL << "recv_tsbpd_delay " << resp->recv_tsbpd_delay; + //DebugL << "send_tsbpd_delay " << resp->send_tsbpd_delay; + } if (keyMaterial && _crypto) { _crypto->loadFromKeyMaterial(keyMaterial); @@ -715,12 +697,17 @@ void SrtCaller::handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sock } onHandShakeFinished(); - return; + return; } void SrtCaller::handleACK(uint8_t *buf, int len, struct sockaddr *addr) { // TraceL; //Acknowledgement of Acknowledgement (ACKACK) control packets are sent to acknowledge the reception of a Full ACK + + if (!_is_handleshake_finished) { + return; + } + ACKPacket ack; if (!ack.loadFromData(buf, len)) { return; @@ -730,7 +717,9 @@ void SrtCaller::handleACK(uint8_t *buf, int len, struct sockaddr *addr) { pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp); pkt->ack_number = ack.ack_number; pkt->storeToData(); - _send_buf->drop(ack.last_ack_pkt_seq_number); + if (_send_buf) { + _send_buf->drop(ack.last_ack_pkt_seq_number); + } sendControlPacket(pkt, true); // TraceL<<"ack number "<(); pkt->loadFromData(buf, len); @@ -774,6 +767,15 @@ void SrtCaller::handleACKACK(uint8_t *buf, int len, struct sockaddr *addr) { } void SrtCaller::handleNAK(uint8_t *buf, int len, struct sockaddr *addr) { + if (!_is_handleshake_finished) { + return; + } + + if (isPlayer()) { + //player should not handle nak + return; + } + //TraceL; NAKPacket pkt; pkt.loadFromData(buf, len); @@ -800,6 +802,15 @@ void SrtCaller::handleNAK(uint8_t *buf, int len, struct sockaddr *addr) { } void SrtCaller::handleDropReq(uint8_t *buf, int len, struct sockaddr *addr) { + if (!_is_handleshake_finished) { + return; + } + + if (!isPlayer()) { + //pusher should not handle drop req + return; + } + MsgDropReqPacket pkt; pkt.loadFromData(buf, len); std::list list; @@ -892,9 +903,9 @@ void SrtCaller::handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockadd } void SrtCaller::handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr) { - //TraceL; - DataPacket::Ptr pkt = std::make_shared(); - pkt->loadFromData(buf, len); + //TraceL; + DataPacket::Ptr pkt = std::make_shared(); + pkt->loadFromData(buf, len); if (_crypto) { auto payload = _crypto->decrypt(pkt, pkt->payloadData(), pkt->payloadSize()); @@ -906,10 +917,10 @@ void SrtCaller::handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr) { pkt->reloadPayload((uint8_t*)payload->data(), payload->size()); } - _estimated_link_capacity_context->inputPacket(_now, pkt); + _estimated_link_capacity_context->inputPacket(_now, pkt); - std::list list; - _recv_buf->inputPacket(pkt, list); + std::list list; + _recv_buf->inputPacket(pkt, list); for (auto& data : list) { if (_last_pkt_seq + 1 != data->packet_seq_number) { TraceL << "pkt lost " << _last_pkt_seq + 1 << "->" << data->packet_seq_number; @@ -1002,20 +1013,13 @@ float SrtCaller::getTimeOutSec() { GET_CONFIG(uint32_t, timeout, SRT::kTimeOutSec); if (timeout <= 0) { WarnL << "config srt " << kTimeOutSec << " not vaild"; - return 5 * 1000; + return 5.0f; } - return (float)timeout * (float)1000; + return (float)timeout; }; std::string SrtCaller::generateStreamId() { - auto streamId = "#!::r=" + _url._app + "/" + _url._stream; - if (_url._vhost != DEFAULT_VHOST) { - streamId += ",h=" +_url._vhost; - } - if (!isPlayer()) { - streamId += ",m=publish"; - } - return streamId; + return _url._streamid; }; uint32_t SrtCaller::generateSocketId() { @@ -1042,6 +1046,21 @@ size_t SrtCaller::getPayloadSize() { return ret; } +size_t SrtCaller::getRecvSpeed() const { + return _socket ? _socket->getRecvSpeed() : 0; +} + +size_t SrtCaller::getRecvTotalBytes() const { + return _socket ? _socket->getRecvTotalBytes() : 0; +} + +size_t SrtCaller::getSendSpeed() const { + return _socket ? _socket->getSendSpeed() : 0; +} + +size_t SrtCaller::getSendTotalBytes() const { + return _socket ? _socket->getSendTotalBytes() : 0; +} } /* namespace mediakit */ diff --git a/src/Srt/SrtCaller.h b/srt/SrtCaller.h similarity index 77% rename from src/Srt/SrtCaller.h rename to srt/SrtCaller.h index 84aa8089..41145479 100644 --- a/src/Srt/SrtCaller.h +++ b/srt/SrtCaller.h @@ -1,199 +1,205 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_SRTCALLER_H -#define ZLMEDIAKIT_SRTCALLER_H - -//srt -#include "srt/Packet.hpp" -#include "srt/Crypto.hpp" -#include "srt/PacketQueue.hpp" -#include "srt/PacketSendQueue.hpp" -#include "srt/Statistic.hpp" - -#include "Poller/EventPoller.h" -#include "Network/Socket.h" -#include "Poller/Timer.h" -#include "Util/TimeTicker.h" -#include "Common/MultiMediaSourceMuxer.h" -#include "Rtp/Decoder.h" -#include "TS/TSMediaSource.h" -#include -#include - - -namespace mediakit { - -// 解析srt 信令url的工具类 -class SrtUrl { -public: - std::string _full_url; - std::string _params; - std::string _host; - uint16_t _port; - std::string _vhost; - std::string _app; - std::string _stream; - -public: - void parse(const std::string &url); -}; - -// 实现了webrtc代理拉流功能 -class SrtCaller : public std::enable_shared_from_this{ -public: - using Ptr = std::shared_ptr; - - using SteadyClock = std::chrono::steady_clock; - using TimePoint = std::chrono::time_point; - - SrtCaller(const toolkit::EventPoller::Ptr &poller); - virtual ~SrtCaller(); - - const toolkit::EventPoller::Ptr &getPoller() const {return _poller;} - - virtual void inputSockData(uint8_t *buf, int len, struct sockaddr *addr); - virtual void onSendTSData(const SRT::Buffer::Ptr &buffer, bool flush); - -protected: - - virtual void onConnect(); - virtual void onHandShakeFinished(); - virtual void onResult(const toolkit::SockException &ex); - - virtual void onSRTData(SRT::DataPacket::Ptr pkt); - - virtual uint16_t getLatency() = 0; - virtual int getLatencyMul(); - virtual int getPktBufSize(); - virtual float getTimeOutSec(); - - virtual bool isPlayer() = 0; - -private: - void doHandshake(); - - void sendHandshakeInduction(); - void sendHandshakeConclusion(); - void sendACKPacket(); - void sendLightACKPacket(); - void sendNAKPacket(std::list &lost_list); - void sendMsgDropReq(uint32_t first, uint32_t last); - void sendKeepLivePacket(); - void sendShutDown(); - void tryAnnounceKeyMaterial(); - void sendControlPacket(SRT::ControlPacket::Ptr pkt, bool flush = true); - void sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, bool flush = false); - void sendPacket(toolkit::Buffer::Ptr pkt, bool flush); - - void handleHandshake(uint8_t *buf, int len, struct sockaddr *addr); - void handleHandshakeInduction(SRT::HandshakePacket &pkt, struct sockaddr *addr); - void handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sockaddr *addr); - void handleACK(uint8_t *buf, int len, struct sockaddr *addr); - void handleACKACK(uint8_t *buf, int len, struct sockaddr *addr); - void handleNAK(uint8_t *buf, int len, struct sockaddr *addr); - void handleDropReq(uint8_t *buf, int len, struct sockaddr *addr); - void handleKeeplive(uint8_t *buf, int len, struct sockaddr *addr); - void handleShutDown(uint8_t *buf, int len, struct sockaddr *addr); - void handlePeerError(uint8_t *buf, int len, struct sockaddr *addr); - void handleCongestionWarning(uint8_t *buf, int len, struct sockaddr *addr); - void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr *addr); - void handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr); - void handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr *addr); - void handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr *addr); - - void checkAndSendAckNak(); - void createTimerForCheckAlive(); - - std::string generateStreamId(); - uint32_t generateSocketId(); - int32_t generateInitSeq(); - size_t getPayloadSize(); - - virtual std::string getPassphrase() = 0; - -protected: - SrtUrl _url; - toolkit::EventPoller::Ptr _poller; - - bool _is_handleshake_finished = false; - -private: - toolkit::Socket::Ptr _socket; - - TimePoint _now; - TimePoint _start_timestamp; - // for calculate rtt for delay - TimePoint _induction_ts; - - //the initial value of RTT is 100 milliseconds - //RTTVar is 50 milliseconds - uint32_t _rtt = 100 * 1000; - uint32_t _rtt_variance = 50 * 1000; - - //local - uint32_t _socket_id = 0; - uint32_t _init_seq_number = 0; - uint32_t _mtu = 1500; - uint32_t _max_flow_window_size = 8192; - uint16_t _delay = 120; - - //peer - uint32_t _sync_cookie = 0; - uint32_t _peer_socket_id; - - // for handshake - SRT::Timer::Ptr _handleshake_timer; - SRT::HandshakePacket::Ptr _handleshake_req; - - // for keeplive - SRT::Ticker _send_ticker; - SRT::Timer::Ptr _keeplive_timer; - - // for alive - SRT::Ticker _alive_ticker; - SRT::Timer::Ptr _alive_timer; - - // for recv - SRT::PacketQueueInterface::Ptr _recv_buf; - uint32_t _last_pkt_seq = 0; - - // Ack - SRT::UTicker _ack_ticker; - uint32_t _last_ack_pkt_seq = 0; - uint32_t _light_ack_pkt_count = 0; - uint32_t _ack_number_count = 0; - std::map _ack_send_timestamp; - // Full Ack - // Link Capacity and Receiving Rate Estimation - std::shared_ptr _pkt_recv_rate_context; - std::shared_ptr _estimated_link_capacity_context; - - // Nak - SRT::UTicker _nak_ticker; - - //for Send - SRT::PacketSendQueue::Ptr _send_buf; - SRT::ResourcePool _packet_pool; - uint32_t _send_packet_seq_number = 0; - uint32_t _send_msg_number = 1; - - //AckAck - uint32_t _last_recv_ackack_seq_num = 0; - - // for encryption - SRT::Crypto::Ptr _crypto; - SRT::Timer::Ptr _announce_timer; - SRT::KeyMaterialPacket::Ptr _announce_req; -}; - -} /* namespace mediakit */ -#endif /* ZLMEDIAKIT_SRTCALLER_H */ - +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_SRTCALLER_H +#define ZLMEDIAKIT_SRTCALLER_H + +//srt +#include "srt/Packet.hpp" +#include "srt/Crypto.hpp" +#include "srt/PacketQueue.hpp" +#include "srt/PacketSendQueue.hpp" +#include "srt/Statistic.hpp" + +#include "Poller/EventPoller.h" +#include "Network/Socket.h" +#include "Poller/Timer.h" +#include "Util/TimeTicker.h" +#include "Common/MultiMediaSourceMuxer.h" +#include "Rtp/Decoder.h" +#include "TS/TSMediaSource.h" +#include +#include + + +namespace mediakit { + +// 解析srt 信令url的工具类 +class SrtUrl { +public: + void parse(const std::string &url); + +public: + std::string _full_url; + std::string _params; + std::string _streamid; + sockaddr_storage _addr; + +private: + uint16_t _port; + std::string _host; +}; + +// 实现了webrtc代理拉流功能 +class SrtCaller : public std::enable_shared_from_this{ +public: + using Ptr = std::shared_ptr; + + using SteadyClock = std::chrono::steady_clock; + using TimePoint = std::chrono::time_point; + + SrtCaller(const toolkit::EventPoller::Ptr &poller); + virtual ~SrtCaller(); + + const toolkit::EventPoller::Ptr &getPoller() const {return _poller;} + + virtual void inputSockData(uint8_t *buf, int len, struct sockaddr *addr); + virtual void onSendTSData(const SRT::Buffer::Ptr &buffer, bool flush); + + size_t getRecvSpeed() const; + size_t getRecvTotalBytes() const; + size_t getSendSpeed() const; + size_t getSendTotalBytes() const; + +protected: + + virtual void onConnect(); + virtual void onHandShakeFinished(); + virtual void onResult(const toolkit::SockException &ex); + + virtual void onSRTData(SRT::DataPacket::Ptr pkt); + + virtual uint16_t getLatency() = 0; + virtual int getLatencyMul(); + virtual int getPktBufSize(); + virtual float getTimeOutSec(); + + virtual bool isPlayer() = 0; + +private: + void doHandshake(); + + void sendHandshakeInduction(); + void sendHandshakeConclusion(); + void sendACKPacket(); + void sendLightACKPacket(); + void sendNAKPacket(std::list &lost_list); + void sendMsgDropReq(uint32_t first, uint32_t last); + void sendKeepLivePacket(); + void sendShutDown(); + void tryAnnounceKeyMaterial(); + void sendControlPacket(SRT::ControlPacket::Ptr pkt, bool flush = true); + void sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, bool flush = false); + void sendPacket(toolkit::Buffer::Ptr pkt, bool flush); + + void handleHandshake(uint8_t *buf, int len, struct sockaddr *addr); + void handleHandshakeInduction(SRT::HandshakePacket &pkt, struct sockaddr *addr); + void handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sockaddr *addr); + void handleACK(uint8_t *buf, int len, struct sockaddr *addr); + void handleACKACK(uint8_t *buf, int len, struct sockaddr *addr); + void handleNAK(uint8_t *buf, int len, struct sockaddr *addr); + void handleDropReq(uint8_t *buf, int len, struct sockaddr *addr); + void handleKeeplive(uint8_t *buf, int len, struct sockaddr *addr); + void handleShutDown(uint8_t *buf, int len, struct sockaddr *addr); + void handlePeerError(uint8_t *buf, int len, struct sockaddr *addr); + void handleCongestionWarning(uint8_t *buf, int len, struct sockaddr *addr); + void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr *addr); + void handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr); + void handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr *addr); + void handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr *addr); + + void checkAndSendAckNak(); + void createTimerForCheckAlive(); + + std::string generateStreamId(); + uint32_t generateSocketId(); + int32_t generateInitSeq(); + size_t getPayloadSize(); + + virtual std::string getPassphrase() = 0; + +protected: + SrtUrl _url; + toolkit::EventPoller::Ptr _poller; + + bool _is_handleshake_finished = false; + +private: + toolkit::Socket::Ptr _socket; + + TimePoint _now; + TimePoint _start_timestamp; + // for calculate rtt for delay + TimePoint _induction_ts; + + //the initial value of RTT is 100 milliseconds + //RTTVar is 50 milliseconds + uint32_t _rtt = 100 * 1000; + uint32_t _rtt_variance = 50 * 1000; + + //local + uint32_t _socket_id = 0; + uint32_t _init_seq_number = 0; + uint32_t _mtu = 1500; + uint32_t _max_flow_window_size = 8192; + uint16_t _delay = 120; + + //peer + uint32_t _sync_cookie = 0; + uint32_t _peer_socket_id; + + // for handshake + SRT::Timer::Ptr _handleshake_timer; + SRT::HandshakePacket::Ptr _handleshake_req; + + // for keeplive + SRT::Ticker _send_ticker; + SRT::Timer::Ptr _keeplive_timer; + + // for alive + SRT::Ticker _alive_ticker; + SRT::Timer::Ptr _alive_timer; + + // for recv + SRT::PacketQueueInterface::Ptr _recv_buf; + uint32_t _last_pkt_seq = 0; + + // Ack + SRT::UTicker _ack_ticker; + uint32_t _last_ack_pkt_seq = 0; + uint32_t _light_ack_pkt_count = 0; + uint32_t _ack_number_count = 0; + std::map _ack_send_timestamp; + // Full Ack + // Link Capacity and Receiving Rate Estimation + std::shared_ptr _pkt_recv_rate_context; + std::shared_ptr _estimated_link_capacity_context; + + // Nak + SRT::UTicker _nak_ticker; + + //for Send + SRT::PacketSendQueue::Ptr _send_buf; + SRT::ResourcePool _packet_pool; + uint32_t _send_packet_seq_number = 0; + uint32_t _send_msg_number = 1; + + //AckAck + uint32_t _last_recv_ackack_seq_num = 0; + + // for encryption + SRT::Crypto::Ptr _crypto; + SRT::Timer::Ptr _announce_timer; + SRT::KeyMaterialPacket::Ptr _announce_req; +}; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_SRTCALLER_H */ + diff --git a/src/Srt/SrtPlayer.cpp b/srt/SrtPlayer.cpp similarity index 82% rename from src/Srt/SrtPlayer.cpp rename to srt/SrtPlayer.cpp index ea20cd77..02d3fc83 100644 --- a/src/Srt/SrtPlayer.cpp +++ b/srt/SrtPlayer.cpp @@ -1,169 +1,184 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#include "SrtPlayer.h" -#include "SrtPlayerImp.h" -#include "Common/config.h" -#include "Http/HlsPlayer.h" - -using namespace toolkit; -using namespace std; - -namespace mediakit { - - -SrtPlayer::SrtPlayer(const EventPoller::Ptr &poller) - : SrtCaller(poller) { - DebugL; -} - -SrtPlayer::~SrtPlayer(void) { - DebugL; -} - -void SrtPlayer::play(const string &strUrl) { - DebugL; - try { - _url.parse(strUrl); - } catch (std::exception &ex) { - onResult(SockException(Err_other, StrPrinter << "illegal srt url:" << ex.what())); - return; - } - onConnect(); - return; -} - -void SrtPlayer::teardown() { - SrtCaller::onResult(SockException(Err_other, StrPrinter << "teardown: " << _url._full_url)); -} - -void SrtPlayer::pause(bool bPause) { - DebugL; -} - -void SrtPlayer::speed(float speed) { - DebugL; -} - -void SrtPlayer::onHandShakeFinished() { - SrtCaller::onHandShakeFinished(); - onResult(SockException(Err_success, "srt play success")); -} - -void SrtPlayer::onResult(const SockException &ex) { - SrtCaller::onResult(ex); - - if (!ex) { - // 播放成功 - onPlayResult(ex); - _benchmark_mode = (*this)[Client::kBenchmarkMode].as(); - - // 播放成功,恢复数据包接收超时定时器 - _recv_ticker.resetTime(); - auto timeout = getTimeOutSec(); - //读取配置文件 - weak_ptr weakSelf = static_pointer_cast(shared_from_this()); - // 创建rtp数据接收超时检测定时器 - _check_timer = std::make_shared(timeout /2, - [weakSelf, timeout]() { - auto strongSelf = weakSelf.lock(); - if (!strongSelf) { - return false; - } - if (strongSelf->_recv_ticker.elapsedTime() > timeout * 1000) { - // 接收媒体数据包超时 - strongSelf->onResult(SockException(Err_timeout, "receive srt media data timeout:" + strongSelf->_url._full_url)); - return false; - } - - return true; - }, getPoller()); - } else { - WarnL << ex.getErrCode() << " " << ex.what(); - if (ex.getErrCode() == Err_shutdown) { - // 主动shutdown的,不触发回调 - return; - } - if (!_is_handleshake_finished) { - onPlayResult(ex); - } else { - onShutdown(ex); - } - } - return; -} - - -void SrtPlayer::onSRTData(SRT::DataPacket::Ptr pkt) { - _recv_ticker.resetTime(); -} - -uint16_t SrtPlayer::getLatency() { - auto latency = (*this)[Client::kLatency].as(); - return (uint16_t)latency ; -} - -float SrtPlayer::getTimeOutSec() { - auto timeoutMS = (*this)[Client::kTimeoutMS].as(); - return (float)timeoutMS / (float)1000; -} - -std::string SrtPlayer::getPassphrase() { - auto passPhrase = (*this)[Client::kPassPhrase].as(); - return passPhrase; -} - -/////////////////////////////////////////////////// -// SrtPlayerImp - -void SrtPlayerImp::onPlayResult(const toolkit::SockException &ex) { - if (ex) { - Super::onPlayResult(ex); - } - //success result only occur when addTrackCompleted - return; -} - -std::vector SrtPlayerImp::getTracks(bool ready /*= true*/) const { - return _demuxer ? static_pointer_cast(_demuxer)->getTracks(ready) : Super::getTracks(ready); -} - -void SrtPlayerImp::addTrackCompleted() { - Super::onPlayResult(toolkit::SockException(toolkit::Err_success, "play success")); -} - -void SrtPlayerImp::onSRTData(SRT::DataPacket::Ptr pkt) { - SrtPlayer::onSRTData(pkt); - - if (_benchmark_mode) { - return; - } - - auto strong_self = shared_from_this(); - if (!_demuxer) { - auto demuxer = std::make_shared(); - demuxer->start(getPoller(), this); - _demuxer = std::move(demuxer); - } - - if (!_decoder && _demuxer) { - _decoder = DecoderImp::createDecoder(DecoderImp::decoder_ts, _demuxer.get()); - } - - if (_decoder && _demuxer) { - _decoder->input(reinterpret_cast(pkt->payloadData()), pkt->payloadSize()); - } - - return; -} - - -} /* namespace mediakit */ - +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "SrtPlayer.h" +#include "SrtPlayerImp.h" +#include "Common/config.h" +#include "Http/HlsPlayer.h" + +using namespace toolkit; +using namespace std; + +namespace mediakit { + + +SrtPlayer::SrtPlayer(const EventPoller::Ptr &poller) + : SrtCaller(poller) { + DebugL; +} + +SrtPlayer::~SrtPlayer(void) { + DebugL; +} + +void SrtPlayer::play(const string &strUrl) { + DebugL; + try { + _url.parse(strUrl); + } catch (std::exception &ex) { + onResult(SockException(Err_other, StrPrinter << "illegal srt url:" << ex.what())); + return; + } + + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->onConnect(); + }); + return; +} + +void SrtPlayer::teardown() { + SrtCaller::onResult(SockException(Err_other, StrPrinter << "teardown: " << _url._full_url)); +} + +void SrtPlayer::pause(bool bPause) { + DebugL; +} + +void SrtPlayer::speed(float speed) { + DebugL; +} + +void SrtPlayer::onHandShakeFinished() { + SrtCaller::onHandShakeFinished(); + onResult(SockException(Err_success, "srt play success")); +} + +void SrtPlayer::onResult(const SockException &ex) { + SrtCaller::onResult(ex); + + if (!ex) { + // 播放成功 + onPlayResult(ex); + _benchmark_mode = (*this)[Client::kBenchmarkMode].as(); + + // 播放成功,恢复数据包接收超时定时器 + _recv_ticker.resetTime(); + auto timeout = getTimeOutSec(); + //读取配置文件 + weak_ptr weakSelf = static_pointer_cast(shared_from_this()); + // 创建rtp数据接收超时检测定时器 + _check_timer = std::make_shared(timeout /2, + [weakSelf, timeout]() { + auto strongSelf = weakSelf.lock(); + if (!strongSelf) { + return false; + } + if (strongSelf->_recv_ticker.elapsedTime() > timeout * 1000) { + // 接收媒体数据包超时 + strongSelf->onResult(SockException(Err_timeout, "receive srt media data timeout:" + strongSelf->_url._full_url)); + return false; + } + + return true; + }, getPoller()); + } else { + WarnL << ex.getErrCode() << " " << ex.what(); + if (ex.getErrCode() == Err_shutdown) { + // 主动shutdown的,不触发回调 + return; + } + if (!_is_handleshake_finished) { + onPlayResult(ex); + } else { + onShutdown(ex); + } + } + return; +} + + +void SrtPlayer::onSRTData(SRT::DataPacket::Ptr pkt) { + _recv_ticker.resetTime(); +} + +uint16_t SrtPlayer::getLatency() { + auto latency = (*this)[Client::kLatency].as(); + return (uint16_t)latency ; +} + +float SrtPlayer::getTimeOutSec() { + auto timeoutMS = (*this)[Client::kTimeoutMS].as(); + return (float)timeoutMS / (float)1000; +} + +std::string SrtPlayer::getPassphrase() { + auto passPhrase = (*this)[Client::kPassPhrase].as(); + return passPhrase; +} + +size_t SrtPlayer::getRecvSpeed() { + return SrtCaller::getRecvSpeed(); +} + +size_t SrtPlayer::getRecvTotalBytes() { + return SrtCaller::getRecvTotalBytes(); +} + +/////////////////////////////////////////////////// +// SrtPlayerImp + +void SrtPlayerImp::onPlayResult(const toolkit::SockException &ex) { + if (ex) { + Super::onPlayResult(ex); + } + //success result only occur when addTrackCompleted + return; +} + +std::vector SrtPlayerImp::getTracks(bool ready /*= true*/) const { + return _demuxer ? static_pointer_cast(_demuxer)->getTracks(ready) : Super::getTracks(ready); +} + +void SrtPlayerImp::addTrackCompleted() { + Super::onPlayResult(toolkit::SockException(toolkit::Err_success, "play success")); +} + +void SrtPlayerImp::onSRTData(SRT::DataPacket::Ptr pkt) { + SrtPlayer::onSRTData(pkt); + + if (_benchmark_mode) { + return; + } + + auto strong_self = shared_from_this(); + if (!_demuxer) { + auto demuxer = std::make_shared(); + demuxer->start(getPoller(), this); + _demuxer = std::move(demuxer); + } + + if (!_decoder && _demuxer) { + _decoder = DecoderImp::createDecoder(DecoderImp::decoder_ts, _demuxer.get()); + } + + if (_decoder && _demuxer) { + _decoder->input(reinterpret_cast(pkt->payloadData()), pkt->payloadSize()); + } + + return; +} + +} /* namespace mediakit */ + diff --git a/src/Srt/SrtPlayer.h b/srt/SrtPlayer.h similarity index 86% rename from src/Srt/SrtPlayer.h rename to srt/SrtPlayer.h index 23d206ca..72d1083c 100644 --- a/src/Srt/SrtPlayer.h +++ b/srt/SrtPlayer.h @@ -1,65 +1,67 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_SRTPLAYER_H -#define ZLMEDIAKIT_SRTPLAYER_H - -#include "Network/Socket.h" -#include "Player/PlayerBase.h" -#include "Poller/Timer.h" -#include "Util/TimeTicker.h" -#include "srt/SrtTransport.hpp" -#include "Http/HttpRequester.h" -#include -#include -#include "SrtCaller.h" - -namespace mediakit { - - -// 实现了srt代理拉流功能 -class SrtPlayer - : public PlayerBase , public SrtCaller { -public: - using Ptr = std::shared_ptr; - - SrtPlayer(const toolkit::EventPoller::Ptr &poller); - ~SrtPlayer() override; - - //// PlayerBase override//// - void play(const std::string &strUrl) override; - void teardown() override; - void pause(bool pause) override; - void speed(float speed) override; - -protected: - - //// SrtCaller override//// - void onHandShakeFinished() override; - void onSRTData(SRT::DataPacket::Ptr pkt) override; - void onResult(const toolkit::SockException &ex) override; - - bool isPlayer() override {return true;} - - uint16_t getLatency() override; - float getTimeOutSec() override; - std::string getPassphrase() override; - -protected: - //是否为性能测试模式 - bool _benchmark_mode = false; - - //超时功能实现 - toolkit::Ticker _recv_ticker; - std::shared_ptr _check_timer; -}; - -} /* namespace mediakit */ -#endif /* ZLMEDIAKIT_SRTPLAYER_H */ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_SRTPLAYER_H +#define ZLMEDIAKIT_SRTPLAYER_H + +#include "Network/Socket.h" +#include "Player/PlayerBase.h" +#include "Poller/Timer.h" +#include "Util/TimeTicker.h" +#include "srt/SrtTransport.hpp" +#include "Http/HttpRequester.h" +#include +#include +#include "SrtCaller.h" + +namespace mediakit { + + +// 实现了srt代理拉流功能 +class SrtPlayer + : public PlayerBase , public SrtCaller { +public: + using Ptr = std::shared_ptr; + + SrtPlayer(const toolkit::EventPoller::Ptr &poller); + ~SrtPlayer() override; + + //// PlayerBase override//// + void play(const std::string &strUrl) override; + void teardown() override; + void pause(bool pause) override; + void speed(float speed) override; + size_t getRecvSpeed() override; + size_t getRecvTotalBytes() override; + +protected: + + //// SrtCaller override//// + void onHandShakeFinished() override; + void onSRTData(SRT::DataPacket::Ptr pkt) override; + void onResult(const toolkit::SockException &ex) override; + + bool isPlayer() override {return true;} + + uint16_t getLatency() override; + float getTimeOutSec() override; + std::string getPassphrase() override; + +protected: + //是否为性能测试模式 + bool _benchmark_mode = false; + + //超时功能实现 + toolkit::Ticker _recv_ticker; + std::shared_ptr _check_timer; +}; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_SRTPLAYER_H */ diff --git a/src/Srt/SrtPlayerImp.h b/srt/SrtPlayerImp.h similarity index 93% rename from src/Srt/SrtPlayerImp.h rename to srt/SrtPlayerImp.h index 0828fe5e..620612b6 100644 --- a/src/Srt/SrtPlayerImp.h +++ b/srt/SrtPlayerImp.h @@ -1,51 +1,51 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_SRtPLAYERIMP_H -#define ZLMEDIAKIT_SRtPLAYERIMP_H - -#include "SrtPlayer.h" - -namespace mediakit { - -class SrtPlayerImp - : public PlayerImp - , private TrackListener { -public: - using Ptr = std::shared_ptr; - using Super = PlayerImp; - - SrtPlayerImp(const toolkit::EventPoller::Ptr &poller) : Super(poller) {} - ~SrtPlayerImp() override { DebugL; } - -private: - //// SrtPlayer override//// - void onSRTData(SRT::DataPacket::Ptr pkt) override; - - //// PlayerBase override//// - void onPlayResult(const toolkit::SockException &ex) override; - std::vector getTracks(bool ready = true) const override; - -private: - //// TrackListener override//// - bool addTrack(const Track::Ptr &track) override { return true; } - void addTrackCompleted() override; - -private: - // for player - DecoderImp::Ptr _decoder; - MediaSinkInterface::Ptr _demuxer; - - // for pusher - TSMediaSource::RingType::RingReader::Ptr _ts_reader; -}; - -} /* namespace mediakit */ -#endif /* ZLMEDIAKIT_SRtPLAYERIMP_H */ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_SRtPLAYERIMP_H +#define ZLMEDIAKIT_SRtPLAYERIMP_H + +#include "SrtPlayer.h" + +namespace mediakit { + +class SrtPlayerImp + : public PlayerImp + , private TrackListener { +public: + using Ptr = std::shared_ptr; + using Super = PlayerImp; + + SrtPlayerImp(const toolkit::EventPoller::Ptr &poller) : Super(poller) {} + ~SrtPlayerImp() override { DebugL; } + +private: + //// SrtPlayer override//// + void onSRTData(SRT::DataPacket::Ptr pkt) override; + + //// PlayerBase override//// + void onPlayResult(const toolkit::SockException &ex) override; + std::vector getTracks(bool ready = true) const override; + +private: + //// TrackListener override//// + bool addTrack(const Track::Ptr &track) override { return true; } + void addTrackCompleted() override; + +private: + // for player + DecoderImp::Ptr _decoder; + MediaSinkInterface::Ptr _demuxer; + + // for pusher + TSMediaSource::RingType::RingReader::Ptr _ts_reader; +}; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_SRtPLAYERIMP_H */ diff --git a/src/Srt/SrtPusher.cpp b/srt/SrtPusher.cpp similarity index 83% rename from src/Srt/SrtPusher.cpp rename to srt/SrtPusher.cpp index 73e2e501..113a2c04 100644 --- a/src/Srt/SrtPusher.cpp +++ b/srt/SrtPusher.cpp @@ -1,116 +1,132 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#include "SrtPusher.h" -#include "Common/config.h" - -using namespace toolkit; -using namespace std; -namespace mediakit { - -SrtPusher::SrtPusher(const EventPoller::Ptr &poller, const TSMediaSource::Ptr &src) : SrtCaller(poller) { - _push_src = src; - DebugL; -} - -SrtPusher::~SrtPusher(void) { - DebugL; -} - -void SrtPusher::publish(const string &strUrl) { - DebugL; - try { - _url.parse(strUrl); - } catch (std::exception &ex) { - onResult(SockException(Err_other, StrPrinter << "illegal srt url:" << ex.what())); - return; - } - onConnect(); - return; -} - -void SrtPusher::teardown() { - SrtCaller::onResult(SockException(Err_other, StrPrinter << "teardown: " << _url._full_url)); -} - -void SrtPusher::onHandShakeFinished() { - SrtCaller::onHandShakeFinished(); - onResult(SockException(Err_success, "srt push success")); - doPublish(); -} - -void SrtPusher::onResult(const SockException &ex) { - SrtCaller::onResult(ex); - - if (!ex) { - onPublishResult(ex); - } else { - WarnL << ex.getErrCode() << " " << ex.what(); - if (ex.getErrCode() == Err_shutdown) { - // 主动shutdown的,不触发回调 - return; - } - if (!_is_handleshake_finished) { - onPublishResult(ex); - } else { - onShutdown(ex); - } - } - return; -} - -uint16_t SrtPusher::getLatency() { - auto latency = (*this)[Client::kLatency].as(); - return (uint16_t)latency ; -} - -float SrtPusher::getTimeOutSec() { - auto timeoutMS = (*this)[Client::kTimeoutMS].as(); - return (float)timeoutMS / (float)1000; -} - -std::string SrtPusher::getPassphrase() { - auto passPhrase = (*this)[Client::kPassPhrase].as(); - return passPhrase; -} - -void SrtPusher::doPublish() { - auto src = _push_src.lock(); - if (!src) { - onResult(SockException(Err_eof, "the media source was released")); - return; - } - // 异步查找直播流 - std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); - _ts_reader = src->getRing()->attach(getPoller()); - _ts_reader->setDetachCB([weak_self]() { - auto strong_self = weak_self.lock(); - if (!strong_self) { - // 本对象已经销毁 - return; - } - strong_self->onShutdown(SockException(Err_shutdown)); - }); - _ts_reader->setReadCB([weak_self](const TSMediaSource::RingDataType &ts_list) { - auto strong_self = weak_self.lock(); - if (!strong_self) { - // 本对象已经销毁 - return; - } - size_t i = 0; - auto size = ts_list->size(); - ts_list->for_each([&](const TSPacket::Ptr &ts) { - strong_self->onSendTSData(ts, ++i == size); - }); - }); -} - -} /* namespace mediakit */ - +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "SrtPusher.h" +#include "Common/config.h" + +using namespace toolkit; +using namespace std; +namespace mediakit { + +SrtPusher::SrtPusher(const EventPoller::Ptr &poller, const TSMediaSource::Ptr &src) : SrtCaller(poller) { + _push_src = src; + DebugL; +} + +SrtPusher::~SrtPusher(void) { + DebugL; +} + +void SrtPusher::publish(const string &strUrl) { + DebugL; + try { + _url.parse(strUrl); + } catch (std::exception &ex) { + onResult(SockException(Err_other, StrPrinter << "illegal srt url:" << ex.what())); + return; + } + + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->onConnect(); + }); + return; +} + +void SrtPusher::teardown() { + SrtCaller::onResult(SockException(Err_other, StrPrinter << "teardown: " << _url._full_url)); +} + +void SrtPusher::onHandShakeFinished() { + SrtCaller::onHandShakeFinished(); + onResult(SockException(Err_success, "srt push success")); + doPublish(); +} + +void SrtPusher::onResult(const SockException &ex) { + SrtCaller::onResult(ex); + + if (!ex) { + onPublishResult(ex); + } else { + WarnL << ex.getErrCode() << " " << ex.what(); + if (ex.getErrCode() == Err_shutdown) { + // 主动shutdown的,不触发回调 + return; + } + if (!_is_handleshake_finished) { + onPublishResult(ex); + } else { + onShutdown(ex); + } + } + return; +} + +uint16_t SrtPusher::getLatency() { + auto latency = (*this)[Client::kLatency].as(); + return (uint16_t)latency ; +} + +float SrtPusher::getTimeOutSec() { + auto timeoutMS = (*this)[Client::kTimeoutMS].as(); + return (float)timeoutMS / (float)1000; +} + +std::string SrtPusher::getPassphrase() { + auto passPhrase = (*this)[Client::kPassPhrase].as(); + return passPhrase; +} + +void SrtPusher::doPublish() { + auto src = _push_src.lock(); + if (!src) { + onResult(SockException(Err_eof, "the media source was released")); + return; + } + // 异步查找直播流 + std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); + _ts_reader = src->getRing()->attach(getPoller()); + _ts_reader->setDetachCB([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + // 本对象已经销毁 + return; + } + strong_self->onShutdown(SockException(Err_shutdown)); + }); + _ts_reader->setReadCB([weak_self](const TSMediaSource::RingDataType &ts_list) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + // 本对象已经销毁 + return; + } + size_t i = 0; + auto size = ts_list->size(); + ts_list->for_each([&](const TSPacket::Ptr &ts) { + strong_self->onSendTSData(ts, ++i == size); + }); + }); +} + +size_t SrtPusher::getSendSpeed() { + return SrtCaller::getSendSpeed(); +} + +size_t SrtPusher::getSendTotalBytes() { + return SrtCaller::getSendTotalBytes(); +} + +} /* namespace mediakit */ + diff --git a/src/Srt/SrtPusher.h b/srt/SrtPusher.h similarity index 92% rename from src/Srt/SrtPusher.h rename to srt/SrtPusher.h index 727b59f3..66b19c28 100644 --- a/src/Srt/SrtPusher.h +++ b/srt/SrtPusher.h @@ -1,59 +1,62 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_SRTPUSHER_H -#define ZLMEDIAKIT_SRTPUSHER_H - -#include "Network/Socket.h" -#include "Pusher/PusherBase.h" -#include "Poller/Timer.h" -#include "Util/TimeTicker.h" -#include "srt/SrtTransport.hpp" -#include "Http/HttpRequester.h" -#include -#include -#include "SrtCaller.h" - -namespace mediakit { - -// 实现了srt代理推流功能 -class SrtPusher - : public PusherBase , public SrtCaller { -public: - using Ptr = std::shared_ptr; - - SrtPusher(const toolkit::EventPoller::Ptr &poller,const TSMediaSource::Ptr &src); - ~SrtPusher() override; - - //// PusherBase override//// - void publish(const std::string &url) override; - void teardown() override; - - void doPublish(); -protected: - - //// SrtCaller override//// - void onHandShakeFinished() override; - void onResult(const toolkit::SockException &ex) override; - - bool isPlayer() override {return false;} - uint16_t getLatency() override; - float getTimeOutSec() override; - std::string getPassphrase() override; - -protected: - std::weak_ptr _push_src; - TSMediaSource::RingType::RingReader::Ptr _ts_reader; -}; - -using SrtPusherImp = PusherImp; - -} /* namespace mediakit */ -#endif /* ZLMEDIAKIT_SRTPUSHER_H */ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_SRTPUSHER_H +#define ZLMEDIAKIT_SRTPUSHER_H + +#include "Network/Socket.h" +#include "Pusher/PusherBase.h" +#include "Poller/Timer.h" +#include "Util/TimeTicker.h" +#include "srt/SrtTransport.hpp" +#include "Http/HttpRequester.h" +#include +#include +#include "SrtCaller.h" + +namespace mediakit { + +// 实现了srt代理推流功能 +class SrtPusher + : public PusherBase , public SrtCaller { +public: + using Ptr = std::shared_ptr; + + SrtPusher(const toolkit::EventPoller::Ptr &poller,const TSMediaSource::Ptr &src); + ~SrtPusher() override; + + //// PusherBase override//// + void publish(const std::string &url) override; + void teardown() override; + + void doPublish(); +protected: + + //// SrtCaller override//// + void onHandShakeFinished() override; + void onResult(const toolkit::SockException &ex) override; + + bool isPlayer() override {return false;} + uint16_t getLatency() override; + float getTimeOutSec() override; + std::string getPassphrase() override; + +protected: + std::weak_ptr _push_src; + TSMediaSource::RingType::RingReader::Ptr _ts_reader; + + size_t getSendSpeed() override; + size_t getSendTotalBytes() override; +}; + +using SrtPusherImp = PusherImp; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_SRTPUSHER_H */ diff --git a/srt/SrtSession.hpp b/srt/SrtSession.hpp index 4064534f..33579585 100644 --- a/srt/SrtSession.hpp +++ b/srt/SrtSession.hpp @@ -1,29 +1,29 @@ -#ifndef ZLMEDIAKIT_SRT_SESSION_H -#define ZLMEDIAKIT_SRT_SESSION_H - -#include "Network/Session.h" -#include "SrtTransport.hpp" - -namespace SRT { - -using namespace toolkit; - -class SrtSession : public Session { -public: - SrtSession(const Socket::Ptr &sock); - - void onRecv(const Buffer::Ptr &) override; - void onError(const SockException &err) override; - void onManager() override; - void attachServer(const toolkit::Server &server) override; - static EventPoller::Ptr queryPoller(const Buffer::Ptr &buffer); - -private: - bool _find_transport = true; - Ticker _ticker; - struct sockaddr_storage _peer_addr; - SrtTransport::Ptr _transport; -}; - -} // namespace SRT -#endif // ZLMEDIAKIT_SRT_SESSION_H \ No newline at end of file +#ifndef ZLMEDIAKIT_SRT_SESSION_H +#define ZLMEDIAKIT_SRT_SESSION_H + +#include "Network/Session.h" +#include "SrtTransport.hpp" + +namespace SRT { + +using namespace toolkit; + +class SrtSession : public Session { +public: + SrtSession(const Socket::Ptr &sock); + + void onRecv(const Buffer::Ptr &) override; + void onError(const SockException &err) override; + void onManager() override; + void attachServer(const toolkit::Server &server) override; + static EventPoller::Ptr queryPoller(const Buffer::Ptr &buffer); + +private: + bool _find_transport = true; + Ticker _ticker; + struct sockaddr_storage _peer_addr; + SrtTransport::Ptr _transport; +}; + +} // namespace SRT +#endif // ZLMEDIAKIT_SRT_SESSION_H diff --git a/srt/SrtTransport.cpp b/srt/SrtTransport.cpp index a968dbcd..d81d17b7 100644 --- a/srt/SrtTransport.cpp +++ b/srt/SrtTransport.cpp @@ -400,7 +400,7 @@ void SrtTransport::sendMsgDropReq(uint32_t first, uint32_t last) { } void SrtTransport::tryAnnounceKeyMaterial() { - //TraceL; + //TraceL; if (!_crypto) { return; diff --git a/srt/SrtTransport.hpp b/srt/SrtTransport.hpp index 36edf093..c9511633 100644 --- a/srt/SrtTransport.hpp +++ b/srt/SrtTransport.hpp @@ -169,8 +169,8 @@ private: // for encryption Crypto::Ptr _crypto; - Timer::Ptr _announce_timer; - KeyMaterialPacket::Ptr _announce_req; + Timer::Ptr _announce_timer; + KeyMaterialPacket::Ptr _announce_req; }; class SrtTransportManager { diff --git a/srt/SrtTransportImp.cpp b/srt/SrtTransportImp.cpp index d0323483..8cd927ff 100644 --- a/srt/SrtTransportImp.cpp +++ b/srt/SrtTransportImp.cpp @@ -117,6 +117,7 @@ bool SrtTransportImp::parseStreamid(std::string &streamid) { _media_info.app = app; _media_info.stream = stream_name; + _media_info.full_url = _media_info.getUrl() + "?" + _media_info.params; TraceL << " mediainfo=" << _media_info.shortUrl() << " params=" << _media_info.params; @@ -144,16 +145,8 @@ void SrtTransportImp::onShutdown(const SockException &ex) { } bool SrtTransportImp::close(mediakit::MediaSource &sender) { - std::string err = StrPrinter << "close media: " << sender.getUrl(); - weak_ptr weak_self = static_pointer_cast(shared_from_this()); - getPoller()->async([weak_self, err]() { - auto strong_self = weak_self.lock(); - if (strong_self) { - strong_self->onShutdown(SockException(Err_shutdown, err)); - // 主动关闭推流,那么不延时注销 - strong_self->_muxer = nullptr; - } - }); + onShutdown(SockException(Err_shutdown, "close media: " + sender.getUrl())); + _muxer = nullptr; return true; } @@ -252,7 +245,7 @@ void SrtTransportImp::doPlay() { weak_ptr weak_session = strong_self->getSession(); strong_self->_ts_reader->setGetInfoCB([weak_session]() { Any ret; - ret.set(static_pointer_cast(weak_session.lock())); + ret.set(static_pointer_cast(weak_session.lock())); return ret; }); strong_self->_ts_reader->setDetachCB([weak_self]() { diff --git a/srt/Statistic.cpp b/srt/Statistic.cpp index 139e39b8..3f4f7e52 100644 --- a/srt/Statistic.cpp +++ b/srt/Statistic.cpp @@ -91,7 +91,7 @@ std::string PacketRecvRateContext::dump(){ } printer <<"\r\n"; - return std::move(printer); + return printer; } EstimatedLinkCapacityContext::EstimatedLinkCapacityContext(TimePoint start) : _start(start) { for (size_t i = 0; i < SIZE; i++) { diff --git a/srt/srt.md b/srt/srt.md index c8a1722e..5e806918 100644 --- a/srt/srt.md +++ b/srt/srt.md @@ -5,7 +5,7 @@ - 拉流只支持ts拉流 - 协议实现 [参考](https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html) - 版本支持(>=1.3.0) -- fec与加密没有实现 +- fec没有实现 ## 使用 diff --git a/srt/srt_en.md b/srt/srt_en.md index b30f660d..53ec6842 100644 --- a/srt/srt_en.md +++ b/srt/srt_en.md @@ -5,7 +5,7 @@ - pull stream payload is ts - protocol impliment [reference](https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html) - version support (>=1.3.0) -- fec and encriyped not support +- fec not support ## usage @@ -40,4 +40,9 @@ like: `ffplay -i srt://192.168.1.105:9000?streamid=#!::r=live/test` -- vlc not support ,because can't set stream id [reference](https://github.com/Haivision/srt/issues/1015) \ No newline at end of file +- vlc pull , Tools->Preferences->All->Inputs/Codecs->Access Modules->SRT,then input Stream ID like below: + + `#!::r=live/test` + + then play `srt://192.168.1.105:9000` + diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 69168683..2d913974 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2016-2022 The ZLMediaKit project authors. All Rights Reserved. +# Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/tests/test_http_head.cpp b/tests/test_http_head.cpp new file mode 100644 index 00000000..385bc372 --- /dev/null +++ b/tests/test_http_head.cpp @@ -0,0 +1,27 @@ +#include "Http/HttpRequester.h" + +int main() { + auto requester = std::make_shared(); + requester->setMethod("HEAD"); + + requester->startRequester( + "http://baidu.com", + + [](const toolkit::SockException &ex, const mediakit::Parser &parser) { + if (ex) { + PrintI("HEAD请求失败: %s", ex.what()); + return; + } + + // 检查HTTP状态码 + if (parser.status() != "200") { + PrintI("HEAD请求返回错误状态: %s", parser.status().c_str()); + return; + } + for (auto &header : parser.getHeader()) { + PrintI("key=%s, val=%s", header.first.c_str(), header.second.c_str()); + } + }); + getchar(); + return 0; +} \ No newline at end of file diff --git a/tests/test_ps.cpp b/tests/test_ps.cpp index 5a78b228..62a6b621 100644 --- a/tests/test_ps.cpp +++ b/tests/test_ps.cpp @@ -64,7 +64,6 @@ public: private: MultiMediaSourceMuxer::Ptr _muxer; - uint64_t timeStamp = 0; uint64_t timeStamp_last = 0; }; diff --git a/tests/test_pusherMp4.cpp b/tests/test_pusherMp4.cpp index bbdee643..ae95901d 100644 --- a/tests/test_pusherMp4.cpp +++ b/tests/test_pusherMp4.cpp @@ -42,6 +42,9 @@ int domain(const string &file, const string &url) { // 根据url获取媒体协议类型,注意大小写 [AUTO-TRANSLATED:3cd6622a] // Get the media protocol type based on the URL, note the case auto schema = strToLower(findSubString(url.data(), nullptr, "://").substr(0, 4)); + if (schema == "webr") { + schema = "rtsp"; + } // 只开启推流协议对应的转协议 [AUTO-TRANSLATED:1c4975ae] // Only enable the protocol conversion corresponding to the push protocol diff --git a/tests/test_rtp_pcap.cpp b/tests/test_rtp_pcap.cpp index 1690b6c3..e5648d5a 100644 --- a/tests/test_rtp_pcap.cpp +++ b/tests/test_rtp_pcap.cpp @@ -83,6 +83,10 @@ struct sniff_tcp { #define TH_URG 0x20 #define TH_ECE 0x40 #define TH_CWR 0x80 + +#if defined(TH_FLAGS) +#undef TH_FLAGS +#endif #define TH_FLAGS (TH_FINTH_SYNTH_RSTTH_ACKTH_URGTH_ECETH_CWR) u_short th_win; /* TCP滑动窗口 */ u_short th_sum; /* 头部校验和 */ @@ -154,7 +158,7 @@ static bool loadFile(const char *path, const EventPoller::Ptr &poller) { return false; } auto total_size = std::make_shared(0); - struct pcap_pkthdr header = {0}; + struct pcap_pkthdr header {}; while (true) { const u_char *pkt_buff = pcap_next(handle.get(), &header); if (!pkt_buff) { diff --git a/webrtc/DtlsTransport.cpp b/webrtc/DtlsTransport.cpp index b55906be..2480ce65 100644 --- a/webrtc/DtlsTransport.cpp +++ b/webrtc/DtlsTransport.cpp @@ -1,1442 +1,1445 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#define MS_CLASS "RTC::DtlsTransport" -// #define MS_LOG_DEV_LEVEL 3 - -#include "DtlsTransport.hpp" -#include "logger.h" -#include -#include -#include -#include -#include -#include // std::sprintf(), std::fopen() -#include // std::memcpy(), std::strcmp() -#include "Util/util.h" -#include "Util/SSLBox.h" -#include "Util/SSLUtil.h" - -using namespace std; - -#define LOG_OPENSSL_ERROR(desc) \ - do \ - { \ - if (ERR_peek_error() == 0) \ - MS_ERROR("OpenSSL error [desc:'%s']", desc); \ - else \ - { \ - int64_t err; \ - while ((err = ERR_get_error()) != 0) \ - { \ - MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ - } \ - ERR_clear_error(); \ - } \ - } while (false) - -/* Static methods for OpenSSL callbacks. */ - -inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) -{ - MS_TRACE(); - - // Always valid since DTLS certificates are self-signed. - return 1; -} - -inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) -{ - if (timerUs == 0) - return 100000; - else if (timerUs >= 4000000) - return 4000000; - else - return 2 * timerUs; -} - -namespace RTC -{ - /* Static. */ - - // clang-format off - static constexpr int DtlsMtu{ 1350 }; - // AES-HMAC: http://tools.ietf.org/html/rfc3711 - static constexpr size_t SrtpMasterKeyLength{ 16 }; - static constexpr size_t SrtpMasterSaltLength{ 14 }; - static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; - // AES-GCM: http://tools.ietf.org/html/rfc7714 - static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; - static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; - static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; - static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; - static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; - static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; - // clang-format on - - /* Class variables. */ - // clang-format off - std::map DtlsTransport::string2FingerprintAlgorithm = - { - { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, - { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, - { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, - { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, - { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } - }; - std::map DtlsTransport::fingerprintAlgorithm2String = - { - { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, - { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, - { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, - { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, - { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } - }; - std::map DtlsTransport::string2Role = - { - { "auto", DtlsTransport::Role::AUTO }, - { "client", DtlsTransport::Role::CLIENT }, - { "server", DtlsTransport::Role::SERVER } - }; - std::vector DtlsTransport::srtpCryptoSuites = - { - { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, - { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, - { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, - { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } - }; - // clang-format on - - INSTANCE_IMP(DtlsTransport::DtlsEnvironment); - - /* Class methods. */ - - DtlsTransport::DtlsEnvironment::DtlsEnvironment() - { - MS_TRACE(); - - // Generate a X509 certificate and private key (unless PEM files are provided). - auto ssl = toolkit::SSL_Initor::Instance().getSSLCtx("", true); - if (!ssl || !ReadCertificateAndPrivateKeyFromContext(ssl.get())) { - GenerateCertificateAndPrivateKey(); - } - - // Create a global SSL_CTX. - CreateSslCtx(); - - // Generate certificate fingerprints. - GenerateFingerprints(); - } - - DtlsTransport::DtlsEnvironment::~DtlsEnvironment() - { - MS_TRACE(); - - if (privateKey) - EVP_PKEY_free(privateKey); - if (certificate) - X509_free(certificate); - if (sslCtx) - SSL_CTX_free(sslCtx); - } - - void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() - { - MS_TRACE(); - - int ret{ 0 }; - EC_KEY* ecKey{ nullptr }; - X509_NAME* certName{ nullptr }; - std::string subject = - std::string("mediasoup") + to_string(rand() % 999999 + 100000); - - // Create key with curve. - ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - - if (!ecKey) - { - LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); - - goto error; - } - - EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); - - // NOTE: This can take some time. - ret = EC_KEY_generate_key(ecKey); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); - - goto error; - } - - // Create a private key object. - privateKey = EVP_PKEY_new(); - - if (!privateKey) - { - LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); - - goto error; - } - - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) - ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); - - goto error; - } - - // The EC key now belongs to the private key, so don't clean it up separately. - ecKey = nullptr; - - // Create the X509 certificate. - certificate = X509_new(); - - if (!certificate) - { - LOG_OPENSSL_ERROR("X509_new() failed"); - - goto error; - } - - // Set version 3 (note that 0 means version 1). - X509_set_version(certificate, 2); - - // Set serial number (avoid default 0). - ASN1_INTEGER_set( - X509_get_serialNumber(certificate), - static_cast(rand() % 999999 + 100000)); - - // Set valid period. - X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. - X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. - - // Set the public key for the certificate using the key. - ret = X509_set_pubkey(certificate, privateKey); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); - - goto error; - } - - // Set certificate fields. - certName = X509_get_subject_name(certificate); - - if (!certName) - { - LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); - - goto error; - } - - X509_NAME_add_entry_by_txt( - certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); - X509_NAME_add_entry_by_txt( - certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); - - // It is self-signed so set the issuer name to be the same as the subject. - ret = X509_set_issuer_name(certificate, certName); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); - - goto error; - } - - // Sign the certificate with its own private key. - ret = X509_sign(certificate, privateKey, EVP_sha1()); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("X509_sign() failed"); - - goto error; - } - - return; - - error: - - if (ecKey) - EC_KEY_free(ecKey); - - if (privateKey) - EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. - - if (certificate) - X509_free(certificate); - - MS_THROW_ERROR("DTLS certificate and private key generation failed"); - } - - bool DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromContext(SSL_CTX *ctx) - { - MS_TRACE(); - certificate = SSL_CTX_get0_certificate(ctx); - if (!certificate) { - return false; - } - X509_up_ref(certificate); - - privateKey = SSL_CTX_get0_privatekey(ctx); - if (!privateKey) { - return false; - } - EVP_PKEY_up_ref(privateKey); - InfoL << "Load webrtc dtls certificate: " << toolkit::SSLUtil::getServerName(certificate); - return true; - } - - void DtlsTransport::DtlsEnvironment::CreateSslCtx() - { - MS_TRACE(); - - std::string dtlsSrtpCryptoSuites; - int ret; - - /* Set the global DTLS context. */ - - // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). - sslCtx = SSL_CTX_new(DTLS_method()); - - if (!sslCtx) - { - LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); - - goto error; - } - - ret = SSL_CTX_use_certificate(sslCtx, certificate); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); - - goto error; - } - - ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); - - goto error; - } - - ret = SSL_CTX_check_private_key(sslCtx); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); - - goto error; - } - - // Set options. - SSL_CTX_set_options( - sslCtx, - SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | - SSL_OP_NO_QUERY_MTU); - - // Don't use sessions cache. - SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); - - // Read always as much into the buffer as possible. - // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL - // versions makes this call required. - SSL_CTX_set_read_ahead(sslCtx, 1); - - SSL_CTX_set_verify_depth(sslCtx, 4); - - // Require certificate from peer. - SSL_CTX_set_verify( - sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); - - // Set SSL info callback. - SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ - static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); - }); - // Set ciphers. - ret = SSL_CTX_set_cipher_list( - sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK:!RC4"); - - if (ret == 0) - { - LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); - - goto error; - } - - // Enable ECDH ciphers. - // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters - // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 - // NOTE: https://bugs.ruby-lang.org/issues/12324 - - // For OpenSSL >= 1.0.2. - SSL_CTX_set_ecdh_auto(sslCtx, 1); - - // Set the "use_srtp" DTLS extension. - for (auto it = DtlsTransport::srtpCryptoSuites.begin(); - it != DtlsTransport::srtpCryptoSuites.end(); - ++it) - { - if (it != DtlsTransport::srtpCryptoSuites.begin()) - dtlsSrtpCryptoSuites += ":"; - - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); - dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; - } - - MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); - - // NOTE: This function returns 0 on success. - ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); - - if (ret != 0) - { - MS_ERROR( - "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); - LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); - - goto error; - } - - return; - - error: - - if (sslCtx) - { - SSL_CTX_free(sslCtx); - sslCtx = nullptr; - } - - MS_THROW_ERROR("SSL context creation failed"); - } - - void DtlsTransport::DtlsEnvironment::GenerateFingerprints() - { - MS_TRACE(); - - for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) - { - const std::string& algorithmString = kv.first; - FingerprintAlgorithm algorithm = kv.second; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{ 0 }; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; - - switch (algorithm) - { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; - - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; - - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; - - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; - - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; - - default: - MS_THROW_ERROR("unknown algorithm"); - } - - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - - if (ret == 0) - { - MS_ERROR("X509_digest() failed"); - MS_THROW_ERROR("Fingerprints generation failed"); - } - - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{ 0 }; i < size; ++i) - { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; - - MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); - - // Store it in the vector. - DtlsTransport::Fingerprint fingerprint; - - fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); - fingerprint.value = hexFingerprint; - - localFingerprints.push_back(fingerprint); - } - } - - /* Instance methods. */ - - DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) - { - MS_TRACE(); - env = DtlsEnvironment::Instance().shared_from_this(); - - /* Set SSL. */ - - this->ssl = SSL_new(env->sslCtx); - - if (!this->ssl) - { - LOG_OPENSSL_ERROR("SSL_new() failed"); - - goto error; - } - - // Set this as custom data. - SSL_set_ex_data(this->ssl, 0, static_cast(this)); - - this->sslBioFromNetwork = BIO_new(BIO_s_mem()); - - if (!this->sslBioFromNetwork) - { - LOG_OPENSSL_ERROR("BIO_new() failed"); - - SSL_free(this->ssl); - - goto error; - } - - this->sslBioToNetwork = BIO_new(BIO_s_mem()); - - if (!this->sslBioToNetwork) - { - LOG_OPENSSL_ERROR("BIO_new() failed"); - - BIO_free(this->sslBioFromNetwork); - SSL_free(this->ssl); - - goto error; - } - - SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); - - // Set the MTU so that we don't send packets that are too large with no fragmentation. - SSL_set_mtu(this->ssl, DtlsMtu); - DTLS_set_link_mtu(this->ssl, DtlsMtu); - - // Set callback handler for setting DTLS timer interval. - DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); - - return; - - error: - - // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as - // well. - if (this->sslBioFromNetwork) - BIO_free(this->sslBioFromNetwork); - - if (this->sslBioToNetwork) - BIO_free(this->sslBioToNetwork); - - if (this->ssl) - SSL_free(this->ssl); - - // NOTE: If this is not catched by the caller the program will abort, but - // this should never happen. - MS_THROW_ERROR("DtlsTransport instance creation failed"); - } - - DtlsTransport::~DtlsTransport() - { - MS_TRACE(); - - if (IsRunning()) - { - // Send close alert to the peer. - SSL_shutdown(this->ssl); - SendPendingOutgoingDtlsData(); - } - - if (this->ssl) - { - SSL_free(this->ssl); - - this->ssl = nullptr; - this->sslBioFromNetwork = nullptr; - this->sslBioToNetwork = nullptr; - } - - // Close the DTLS timer. - this->timer = nullptr; - } - - void DtlsTransport::Dump() const - { - MS_TRACE(); - - std::string state{ "new" }; - std::string role{ "none " }; - - switch (this->state) - { - case DtlsState::CONNECTING: - state = "connecting"; - break; - case DtlsState::CONNECTED: - state = "connected"; - break; - case DtlsState::FAILED: - state = "failed"; - break; - case DtlsState::CLOSED: - state = "closed"; - break; - default:; - } - - switch (this->localRole) - { - case Role::AUTO: - role = "auto"; - break; - case Role::SERVER: - role = "server"; - break; - case Role::CLIENT: - role = "client"; - break; - default:; - } - - MS_DUMP(""); - MS_DUMP(" state : %s", state.c_str()); - MS_DUMP(" role : %s", role.c_str()); - MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); - MS_DUMP(""); - } - - void DtlsTransport::Run(Role localRole) - { - MS_TRACE(); - - MS_ASSERT( - localRole == Role::CLIENT || localRole == Role::SERVER, - "local DTLS role must be 'client' or 'server'"); - - Role previousLocalRole = this->localRole; - - if (localRole == previousLocalRole) - { - MS_ERROR("same local DTLS role provided, doing nothing"); - - return; - } - - // If the previous local DTLS role was 'client' or 'server' do reset. - if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) - { - MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); - - Reset(); - } - - // Update local role. - this->localRole = localRole; - - // Set state and notify the listener. - this->state = DtlsState::CONNECTING; - this->listener->OnDtlsTransportConnecting(this); - - switch (this->localRole) - { - case Role::CLIENT: - { - MS_DEBUG_TAG(dtls, "running [role:client]"); - - SSL_set_connect_state(this->ssl); - SSL_do_handshake(this->ssl); - SendPendingOutgoingDtlsData(); - SetTimeout(); - - break; - } - - case Role::SERVER: - { - MS_DEBUG_TAG(dtls, "running [role:server]"); - - SSL_set_accept_state(this->ssl); - SSL_do_handshake(this->ssl); - - break; - } - - default: - { - MS_ABORT("invalid local DTLS role"); - } - } - } - - bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) - { - MS_TRACE(); - - MS_ASSERT( - fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); - - this->remoteFingerprint = fingerprint; - - // The remote fingerpring may have been set after DTLS handshake was done, - // so we may need to process it now. - if (this->handshakeDone && this->state != DtlsState::CONNECTED) - { - MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); - - return ProcessHandshake(); - } - - return true; - } - - void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) - { - MS_TRACE(); - - int written; - int read; - - if (!IsRunning()) - { - MS_WARN_TAG(nullptr,"cannot process data while not running"); - return; - } - - // Write the received DTLS data into the sslBioFromNetwork. - written = - BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); - - if (written != static_cast(len)) - { - MS_WARN_TAG( - dtls, - "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", - static_cast(written), - len); - } - - // Must call SSL_read() to process received DTLS data. - read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); - - // Send data if it's ready. - SendPendingOutgoingDtlsData(); - - // Check SSL status and return if it is bad/closed. - if (!CheckStatus(read)) - return; - - // Set/update the DTLS timeout. - if (!SetTimeout()) - return; - - // Application data received. Notify to the listener. - if (read > 0) - { - // It is allowed to receive DTLS data even before validating remote fingerprint. - if (!this->handshakeDone) - { - MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); - - return; - } - - // Notify the listener. - this->listener->OnDtlsTransportApplicationDataReceived( - this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); - } - } - - void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) - { - MS_TRACE(); - - // We cannot send data to the peer if its remote fingerprint is not validated. - if (this->state != DtlsState::CONNECTED) - { - MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); - - return; - } - - if (len == 0) - { - MS_WARN_TAG(dtls, "ignoring 0 length data"); - - return; - } - - int written; - - written = SSL_write(this->ssl, static_cast(data), static_cast(len)); - - if (written < 0) - { - LOG_OPENSSL_ERROR("SSL_write() failed"); - - if (!CheckStatus(written)) - return; - } - else if (written != static_cast(len)) - { - MS_WARN_TAG( - dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); - } - - // Send data. - SendPendingOutgoingDtlsData(); - } - - void DtlsTransport::Reset() - { - MS_TRACE(); - - int ret; - - if (!IsRunning()) - return; - - MS_WARN_TAG(dtls, "resetting DTLS transport"); - - // Stop the DTLS timer. - this->timer = nullptr; - - // We need to reset the SSL instance so we need to "shutdown" it, but we - // don't want to send a Close Alert to the peer, so just don't call - // SendPendingOutgoingDTLSData(). - SSL_shutdown(this->ssl); - - this->localRole = Role::NONE; - this->state = DtlsState::NEW; - this->handshakeDone = false; - this->handshakeDoneNow = false; - - // Reset SSL status. - // NOTE: For this to properly work, SSL_shutdown() must be called before. - // NOTE: This may fail if not enough DTLS handshake data has been received, - // but we don't care so just clear the error queue. - ret = SSL_clear(this->ssl); - - if (ret == 0) - ERR_clear_error(); - } - - inline bool DtlsTransport::CheckStatus(int returnCode) - { - MS_TRACE(); - - int err; - bool wasHandshakeDone = this->handshakeDone; - - err = SSL_get_error(this->ssl, returnCode); - - switch (err) - { - case SSL_ERROR_NONE: - break; - - case SSL_ERROR_SSL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); - break; - - case SSL_ERROR_WANT_READ: - break; - - case SSL_ERROR_WANT_WRITE: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); - break; - - case SSL_ERROR_WANT_X509_LOOKUP: - MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); - break; - - case SSL_ERROR_SYSCALL: - LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); - break; - - case SSL_ERROR_ZERO_RETURN: - break; - - case SSL_ERROR_WANT_CONNECT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); - break; - - case SSL_ERROR_WANT_ACCEPT: - MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); - break; - - default: - MS_WARN_TAG(dtls, "SSL status: unknown error"); - } - - // Check if the handshake (or re-handshake) has been done right now. - if (this->handshakeDoneNow) - { - this->handshakeDoneNow = false; - this->handshakeDone = true; - - // Stop the timer. - this->timer = nullptr; - - // Process the handshake just once (ignore if DTLS renegotiation). - if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) - return ProcessHandshake(); - - return true; - } - // Check if the peer sent close alert or a fatal error happened. - else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) - { - if (this->state == DtlsState::CONNECTED) - { - MS_DEBUG_TAG(dtls, "disconnected"); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::CLOSED; - this->listener->OnDtlsTransportClosed(this); - } - else - { - MS_WARN_TAG(dtls, "connection failed"); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - } - - return false; - } - else - { - return true; - } - } - - inline void DtlsTransport::SendPendingOutgoingDtlsData() - { - MS_TRACE(); - - if (BIO_eof(this->sslBioToNetwork)) - return; - - int64_t read; - char* data{ nullptr }; - - read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT - - if (read <= 0) - return; - - MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); - - // Notify the listener. - this->listener->OnDtlsTransportSendData( - this, reinterpret_cast(data), static_cast(read)); - - // Clear the BIO buffer. - // NOTE: the (void) avoids the -Wunused-value warning. - (void)BIO_reset(this->sslBioToNetwork); - } - - inline bool DtlsTransport::SetTimeout() - { - MS_TRACE(); - - MS_ASSERT( - this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, - "invalid DTLS state"); - - int64_t ret; - struct timeval dtlsTimeout{ 0, 0 }; - uint64_t timeoutMs; - - // NOTE: If ret == 0 then ignore the value in dtlsTimeout. - // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. - ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT - - if (ret == 0) - return true; - - timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); - - if (timeoutMs == 0) - { - return true; - } - else if (timeoutMs < 30000) - { - MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); - - weak_ptr weak_self = shared_from_this(); - this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ - auto strong_self = weak_self.lock(); - if(strong_self){ - strong_self->OnTimer(); - } - return true; - }, this->poller); - - return true; - } - // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. - else - { - MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - - return false; - } - } - - inline bool DtlsTransport::ProcessHandshake() - { - MS_TRACE(); - - MS_ASSERT(this->handshakeDone, "handshake not done yet"); - MS_ASSERT( - this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); - - // Validate the remote fingerprint. - if (!CheckRemoteFingerprint()) - { - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - - return false; - } - - // Get the negotiated SRTP crypto suite. - RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); - - if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) - { - // Extract the SRTP keys (will notify the listener with them). - ExtractSrtpKeys(srtpCryptoSuite); - - return true; - } - - // NOTE: We assume that "use_srtp" DTLS extension is required even if - // there is no audio/video. - MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); - - Reset(); - - // Set state and notify the listener. - this->state = DtlsState::FAILED; - this->listener->OnDtlsTransportFailed(this); - - return false; - } - - inline bool DtlsTransport::CheckRemoteFingerprint() - { - MS_TRACE(); - - MS_ASSERT( - this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); - - X509* certificate; - uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; - unsigned int size{ 0 }; - char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; - const EVP_MD* hashFunction; - int ret; - - certificate = SSL_get_peer_certificate(this->ssl); - - if (!certificate) - { - MS_WARN_TAG(dtls, "no certificate was provided by the peer"); - - return false; - } - - switch (this->remoteFingerprint.algorithm) - { - case FingerprintAlgorithm::SHA1: - hashFunction = EVP_sha1(); - break; - - case FingerprintAlgorithm::SHA224: - hashFunction = EVP_sha224(); - break; - - case FingerprintAlgorithm::SHA256: - hashFunction = EVP_sha256(); - break; - - case FingerprintAlgorithm::SHA384: - hashFunction = EVP_sha384(); - break; - - case FingerprintAlgorithm::SHA512: - hashFunction = EVP_sha512(); - break; - - default: - MS_ABORT("unknown algorithm"); - } - - // Compare the remote fingerprint with the value given via signaling. - ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); - - if (ret == 0) - { - MS_ERROR("X509_digest() failed"); - - X509_free(certificate); - - return false; - } - - // Convert to hexadecimal format in uppercase with colons. - for (unsigned int i{ 0 }; i < size; ++i) - { - std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); - } - hexFingerprint[(size * 3) - 1] = '\0'; - - if (this->remoteFingerprint.value != hexFingerprint) - { - MS_WARN_TAG( - dtls, - "fingerprint in the remote certificate (%s) does not match the announced one (%s)", - hexFingerprint, - this->remoteFingerprint.value.c_str()); - X509_free(certificate); - return false; - } - - MS_DEBUG_TAG(dtls, "valid remote fingerprint"); - - // Get the remote certificate in PEM format. - - BIO* bio = BIO_new(BIO_s_mem()); - - // Ensure the underlying BUF_MEM structure is also freed. - // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since - // BIO_set_close() always returns 1. - (void)BIO_set_close(bio, BIO_CLOSE); - - ret = PEM_write_bio_X509(bio, certificate); - - if (ret != 1) - { - LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); - - X509_free(certificate); - BIO_free(bio); - - return false; - } - - BUF_MEM* mem; - - BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] - - if (!mem || !mem->data || mem->length == 0u) - { - LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); - - X509_free(certificate); - BIO_free(bio); - - return false; - } - - this->remoteCert = std::string(mem->data, mem->length); - - X509_free(certificate); - BIO_free(bio); - - return true; - } - - inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) - { - MS_TRACE(); - - size_t srtpKeyLength{ 0 }; - size_t srtpSaltLength{ 0 }; - size_t srtpMasterLength{ 0 }; - - switch (srtpCryptoSuite) - { - case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: - case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: - { - srtpKeyLength = SrtpMasterKeyLength; - srtpSaltLength = SrtpMasterSaltLength; - srtpMasterLength = SrtpMasterLength; - - break; - } - - case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: - { - srtpKeyLength = SrtpAesGcm256MasterKeyLength; - srtpSaltLength = SrtpAesGcm256MasterSaltLength; - srtpMasterLength = SrtpAesGcm256MasterLength; - - break; - } - - case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: - { - srtpKeyLength = SrtpAesGcm128MasterKeyLength; - srtpSaltLength = SrtpAesGcm128MasterSaltLength; - srtpMasterLength = SrtpAesGcm128MasterLength; - - break; - } - - default: - { - MS_ABORT("unknown SRTP crypto suite"); - } - } - - auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; - uint8_t* srtpLocalKey{ nullptr }; - uint8_t* srtpLocalSalt{ nullptr }; - uint8_t* srtpRemoteKey{ nullptr }; - uint8_t* srtpRemoteSalt{ nullptr }; - auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; - auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; - int ret; - - ret = SSL_export_keying_material( - this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); - - MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); - - switch (this->localRole) - { - case Role::SERVER: - { - srtpRemoteKey = srtpMaterial; - srtpLocalKey = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; - - break; - } - - case Role::CLIENT: - { - srtpLocalKey = srtpMaterial; - srtpRemoteKey = srtpLocalKey + srtpKeyLength; - srtpLocalSalt = srtpRemoteKey + srtpKeyLength; - srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; - - break; - } - - default: - { - MS_ABORT("no DTLS role set"); - } - } - - // Create the SRTP local master key. - std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); - std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); - // Create the SRTP remote master key. - std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); - std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); - - // Set state and notify the listener. - this->state = DtlsState::CONNECTED; - this->listener->OnDtlsTransportConnected( - this, - srtpCryptoSuite, - srtpLocalMasterKey, - srtpMasterLength, - srtpRemoteMasterKey, - srtpMasterLength, - this->remoteCert); - - delete[] srtpMaterial; - delete[] srtpLocalMasterKey; - delete[] srtpRemoteMasterKey; - } - - inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() - { - MS_TRACE(); - - RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; - - // Ensure that the SRTP crypto suite has been negotiated. - // NOTE: This is a OpenSSL type. - SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); - - if (!sslSrtpCryptoSuite) - return negotiatedSrtpCryptoSuite; - - // Get the negotiated SRTP crypto suite. - for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) - { - SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); - - if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) - { - MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); - - negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; - } - } - - MS_ASSERT( - negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, - "chosen SRTP crypto suite is not an available one"); - - return negotiatedSrtpCryptoSuite; - } - - inline void DtlsTransport::OnSslInfo(int where, int ret) - { - MS_TRACE(); - - int w = where & -SSL_ST_MASK; - const char* role; - - if ((w & SSL_ST_CONNECT) != 0) - role = "client"; - else if ((w & SSL_ST_ACCEPT) != 0) - role = "server"; - else - role = "undefined"; - - if ((where & SSL_CB_LOOP) != 0) - { - MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); - } - else if ((where & SSL_CB_ALERT) != 0) - { - const char* alertType; - - switch (*SSL_alert_type_string(ret)) - { - case 'W': - alertType = "warning"; - break; - - case 'F': - alertType = "fatal"; - break; - - default: - alertType = "undefined"; - } - - if ((where & SSL_CB_READ) != 0) - { - MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - else if ((where & SSL_CB_WRITE) != 0) - { - MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - else - { - MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); - } - } - else if ((where & SSL_CB_EXIT) != 0) - { - if (ret == 0) - MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); - else if (ret < 0) - MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); - } - else if ((where & SSL_CB_HANDSHAKE_START) != 0) - { - MS_DEBUG_TAG(dtls, "DTLS handshake start"); - } - else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) - { - MS_DEBUG_TAG(dtls, "DTLS handshake done"); - - this->handshakeDoneNow = true; - } - - // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon - // receipt of a close alert does not work (the flag is set after this callback). - } - - inline void DtlsTransport::OnTimer() - { - MS_TRACE(); - - // Workaround for https://github.com/openssl/openssl/issues/7998. - if (this->handshakeDone) - { - // MS_DEBUG_DEV("handshake is done so return"); - return; - } - - DTLSv1_handle_timeout(this->ssl); - - // If required, send DTLS data. - SendPendingOutgoingDtlsData(); - - // Set the DTLS timer again. - SetTimeout(); - } -} // namespace RTC +/** +ISC License + +Copyright © 2015, Iñaki Baz Castillo + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#define MS_CLASS "RTC::DtlsTransport" +// #define MS_LOG_DEV_LEVEL 3 + +#include "DtlsTransport.hpp" +#include "logger.h" +#include +#include +#include +#include +#include +#include // std::sprintf(), std::fopen() +#include // std::memcpy(), std::strcmp() +#include "Util/util.h" +#include "Util/SSLBox.h" +#include "Util/SSLUtil.h" + +using namespace std; +using namespace toolkit; + +#define LOG_OPENSSL_ERROR(desc) \ + do \ + { \ + if (ERR_peek_error() == 0) \ + MS_ERROR("OpenSSL error [desc:'%s']", desc); \ + else \ + { \ + int64_t err; \ + while ((err = ERR_get_error()) != 0) \ + { \ + MS_ERROR("OpenSSL error [desc:'%s', error:'%s']", desc, ERR_error_string(err, nullptr)); \ + } \ + ERR_clear_error(); \ + } \ + } while (false) + +/* Static methods for OpenSSL callbacks. */ + +inline static int onSslCertificateVerify(int /*preverifyOk*/, X509_STORE_CTX* /*ctx*/) +{ + MS_TRACE(); + + // Always valid since DTLS certificates are self-signed. + return 1; +} + +inline static unsigned int onSslDtlsTimer(SSL* /*ssl*/, unsigned int timerUs) +{ + if (timerUs == 0) + return 100000; + else if (timerUs >= 4000000) + return 4000000; + else + return 2 * timerUs; +} + +namespace RTC +{ + /* Static. */ + + // clang-format off + static constexpr int DtlsMtu{ 1350 }; + // AES-HMAC: http://tools.ietf.org/html/rfc3711 + static constexpr size_t SrtpMasterKeyLength{ 16 }; + static constexpr size_t SrtpMasterSaltLength{ 14 }; + static constexpr size_t SrtpMasterLength{ SrtpMasterKeyLength + SrtpMasterSaltLength }; + // AES-GCM: http://tools.ietf.org/html/rfc7714 + static constexpr size_t SrtpAesGcm256MasterKeyLength{ 32 }; + static constexpr size_t SrtpAesGcm256MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm256MasterLength{ SrtpAesGcm256MasterKeyLength + SrtpAesGcm256MasterSaltLength }; + static constexpr size_t SrtpAesGcm128MasterKeyLength{ 16 }; + static constexpr size_t SrtpAesGcm128MasterSaltLength{ 12 }; + static constexpr size_t SrtpAesGcm128MasterLength{ SrtpAesGcm128MasterKeyLength + SrtpAesGcm128MasterSaltLength }; + // clang-format on + + /* Class variables. */ + // clang-format off + std::map DtlsTransport::string2FingerprintAlgorithm = + { + { "sha-1", DtlsTransport::FingerprintAlgorithm::SHA1 }, + { "sha-224", DtlsTransport::FingerprintAlgorithm::SHA224 }, + { "sha-256", DtlsTransport::FingerprintAlgorithm::SHA256 }, + { "sha-384", DtlsTransport::FingerprintAlgorithm::SHA384 }, + { "sha-512", DtlsTransport::FingerprintAlgorithm::SHA512 } + }; + std::map DtlsTransport::fingerprintAlgorithm2String = + { + { DtlsTransport::FingerprintAlgorithm::SHA1, "sha-1" }, + { DtlsTransport::FingerprintAlgorithm::SHA224, "sha-224" }, + { DtlsTransport::FingerprintAlgorithm::SHA256, "sha-256" }, + { DtlsTransport::FingerprintAlgorithm::SHA384, "sha-384" }, + { DtlsTransport::FingerprintAlgorithm::SHA512, "sha-512" } + }; + std::map DtlsTransport::string2Role = + { + { "auto", DtlsTransport::Role::AUTO }, + { "client", DtlsTransport::Role::CLIENT }, + { "server", DtlsTransport::Role::SERVER } + }; + std::vector DtlsTransport::srtpCryptoSuites = + { + { RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM, "SRTP_AEAD_AES_256_GCM" }, + { RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM, "SRTP_AEAD_AES_128_GCM" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80, "SRTP_AES128_CM_SHA1_80" }, + { RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32, "SRTP_AES128_CM_SHA1_32" } + }; + // clang-format on + + INSTANCE_IMP(DtlsTransport::DtlsEnvironment); + + /* Class methods. */ + + DtlsTransport::DtlsEnvironment::DtlsEnvironment() + { + MS_TRACE(); + + // Generate a X509 certificate and private key (unless PEM files are provided). + auto ssl = toolkit::SSL_Initor::Instance().getSSLCtx("", true); + if (!ssl || !ReadCertificateAndPrivateKeyFromContext(ssl.get())) { + GenerateCertificateAndPrivateKey(); + } + + // Create a global SSL_CTX. + CreateSslCtx(); + + // Generate certificate fingerprints. + GenerateFingerprints(); + } + + DtlsTransport::DtlsEnvironment::~DtlsEnvironment() + { + MS_TRACE(); + + if (privateKey) + EVP_PKEY_free(privateKey); + if (certificate) + X509_free(certificate); + if (sslCtx) + SSL_CTX_free(sslCtx); + } + + void DtlsTransport::DtlsEnvironment::GenerateCertificateAndPrivateKey() + { + MS_TRACE(); + + int ret{ 0 }; + EC_KEY* ecKey{ nullptr }; + X509_NAME* certName{ nullptr }; + std::string subject = + std::string("mediasoup") + to_string(rand() % 999999 + 100000); + + // Create key with curve. + ecKey = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + + if (!ecKey) + { + LOG_OPENSSL_ERROR("EC_KEY_new_by_curve_name() failed"); + + goto error; + } + + EC_KEY_set_asn1_flag(ecKey, OPENSSL_EC_NAMED_CURVE); + + // NOTE: This can take some time. + ret = EC_KEY_generate_key(ecKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("EC_KEY_generate_key() failed"); + + goto error; + } + + // Create a private key object. + privateKey = EVP_PKEY_new(); + + if (!privateKey) + { + LOG_OPENSSL_ERROR("EVP_PKEY_new() failed"); + + goto error; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast) + ret = EVP_PKEY_assign_EC_KEY(privateKey, ecKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("EVP_PKEY_assign_EC_KEY() failed"); + + goto error; + } + + // The EC key now belongs to the private key, so don't clean it up separately. + ecKey = nullptr; + + // Create the X509 certificate. + certificate = X509_new(); + + if (!certificate) + { + LOG_OPENSSL_ERROR("X509_new() failed"); + + goto error; + } + + // Set version 3 (note that 0 means version 1). + X509_set_version(certificate, 2); + + // Set serial number (avoid default 0). + ASN1_INTEGER_set( + X509_get_serialNumber(certificate), + static_cast(rand() % 999999 + 100000)); + + // Set valid period. + X509_gmtime_adj(X509_get_notBefore(certificate), -315360000); // -10 years. + X509_gmtime_adj(X509_get_notAfter(certificate), 315360000); // 10 years. + + // Set the public key for the certificate using the key. + ret = X509_set_pubkey(certificate, privateKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_pubkey() failed"); + + goto error; + } + + // Set certificate fields. + certName = X509_get_subject_name(certificate); + + if (!certName) + { + LOG_OPENSSL_ERROR("X509_get_subject_name() failed"); + + goto error; + } + + X509_NAME_add_entry_by_txt( + certName, "O", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + X509_NAME_add_entry_by_txt( + certName, "CN", MBSTRING_ASC, reinterpret_cast(subject.c_str()), -1, -1, 0); + + // It is self-signed so set the issuer name to be the same as the subject. + ret = X509_set_issuer_name(certificate, certName); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_set_issuer_name() failed"); + + goto error; + } + + // Sign the certificate with its own private key. + ret = X509_sign(certificate, privateKey, EVP_sha1()); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("X509_sign() failed"); + + goto error; + } + + return; + + error: + + if (ecKey) + EC_KEY_free(ecKey); + + if (privateKey) + EVP_PKEY_free(privateKey); // NOTE: This also frees the EC key. + + if (certificate) + X509_free(certificate); + + MS_THROW_ERROR("DTLS certificate and private key generation failed"); + } + + bool DtlsTransport::DtlsEnvironment::ReadCertificateAndPrivateKeyFromContext(SSL_CTX *ctx) + { + MS_TRACE(); + certificate = SSL_CTX_get0_certificate(ctx); + if (!certificate) { + return false; + } + X509_up_ref(certificate); + + privateKey = SSL_CTX_get0_privatekey(ctx); + if (!privateKey) { + return false; + } + EVP_PKEY_up_ref(privateKey); + InfoL << "Load webrtc dtls certificate: " << toolkit::SSLUtil::getServerName(certificate); + return true; + } + + void DtlsTransport::DtlsEnvironment::CreateSslCtx() + { + MS_TRACE(); + + std::string dtlsSrtpCryptoSuites; + int ret; + + /* Set the global DTLS context. */ + + // Both DTLS 1.0 and 1.2 (requires OpenSSL >= 1.1.0). + sslCtx = SSL_CTX_new(DTLS_method()); + + if (!sslCtx) + { + LOG_OPENSSL_ERROR("SSL_CTX_new() failed"); + + goto error; + } + + ret = SSL_CTX_use_certificate(sslCtx, certificate); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_certificate() failed"); + + goto error; + } + + ret = SSL_CTX_use_PrivateKey(sslCtx, privateKey); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_use_PrivateKey() failed"); + + goto error; + } + + ret = SSL_CTX_check_private_key(sslCtx); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_check_private_key() failed"); + + goto error; + } + + // Set options. + SSL_CTX_set_options( + sslCtx, + SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_TICKET | SSL_OP_SINGLE_ECDH_USE | + SSL_OP_NO_QUERY_MTU); + + // Don't use sessions cache. + SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_OFF); + + // Read always as much into the buffer as possible. + // NOTE: This is the default for DTLS, but a bug in non latest OpenSSL + // versions makes this call required. + SSL_CTX_set_read_ahead(sslCtx, 1); + + SSL_CTX_set_verify_depth(sslCtx, 4); + + // Require certificate from peer. + SSL_CTX_set_verify( + sslCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, onSslCertificateVerify); + + // Set SSL info callback. + SSL_CTX_set_info_callback(sslCtx, [](const SSL* ssl, int where, int ret){ + static_cast(SSL_get_ex_data(ssl, 0))->OnSslInfo(where, ret); + }); + // Set ciphers. + ret = SSL_CTX_set_cipher_list( + sslCtx, "DEFAULT:!NULL:!aNULL:!SHA256:!SHA384:!aECDH:!AESGCM+AES256:!aPSK:!RC4"); + + if (ret == 0) + { + LOG_OPENSSL_ERROR("SSL_CTX_set_cipher_list() failed"); + + goto error; + } + + // Enable ECDH ciphers. + // DOC: http://en.wikibooks.org/wiki/OpenSSL/Diffie-Hellman_parameters + // NOTE: https://code.google.com/p/chromium/issues/detail?id=406458 + // NOTE: https://bugs.ruby-lang.org/issues/12324 + + // For OpenSSL >= 1.0.2. + SSL_CTX_set_ecdh_auto(sslCtx, 1); + + // Set the "use_srtp" DTLS extension. + for (auto it = DtlsTransport::srtpCryptoSuites.begin(); + it != DtlsTransport::srtpCryptoSuites.end(); + ++it) + { + if (it != DtlsTransport::srtpCryptoSuites.begin()) + dtlsSrtpCryptoSuites += ":"; + + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(*it); + dtlsSrtpCryptoSuites += cryptoSuiteEntry->name; + } + + MS_DEBUG_2TAGS(dtls, srtp, "setting SRTP cryptoSuites for DTLS: %s", dtlsSrtpCryptoSuites.c_str()); + + // NOTE: This function returns 0 on success. + ret = SSL_CTX_set_tlsext_use_srtp(sslCtx, dtlsSrtpCryptoSuites.c_str()); + + if (ret != 0) + { + MS_ERROR( + "SSL_CTX_set_tlsext_use_srtp() failed when entering '%s'", dtlsSrtpCryptoSuites.c_str()); + LOG_OPENSSL_ERROR("SSL_CTX_set_tlsext_use_srtp() failed"); + + goto error; + } + + return; + + error: + + if (sslCtx) + { + SSL_CTX_free(sslCtx); + sslCtx = nullptr; + } + + MS_THROW_ERROR("SSL context creation failed"); + } + + void DtlsTransport::DtlsEnvironment::GenerateFingerprints() + { + MS_TRACE(); + + for (auto& kv : DtlsTransport::string2FingerprintAlgorithm) + { + const std::string& algorithmString = kv.first; + FingerprintAlgorithm algorithm = kv.second; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; + + switch (algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; + + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; + + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; + + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; + + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; + + default: + MS_THROW_ERROR("unknown algorithm"); + } + + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); + MS_THROW_ERROR("Fingerprints generation failed"); + } + + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; + + MS_DEBUG_TAG(dtls, "%-7s fingerprint: %s", algorithmString.c_str(), hexFingerprint); + + // Store it in the vector. + DtlsTransport::Fingerprint fingerprint; + + fingerprint.algorithm = DtlsTransport::GetFingerprintAlgorithm(algorithmString); + fingerprint.value = hexFingerprint; + + localFingerprints.push_back(fingerprint); + } + } + + /* Instance methods. */ + + DtlsTransport::DtlsTransport(EventPoller::Ptr poller,Listener* listener) : poller(std::move(poller)), listener(listener) + { + MS_TRACE(); + env = DtlsEnvironment::Instance().shared_from_this(); + + /* Set SSL. */ + + this->ssl = SSL_new(env->sslCtx); + + if (!this->ssl) + { + LOG_OPENSSL_ERROR("SSL_new() failed"); + + goto error; + } + + // Set this as custom data. + SSL_set_ex_data(this->ssl, 0, static_cast(this)); + + this->sslBioFromNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioFromNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + SSL_free(this->ssl); + + goto error; + } + + this->sslBioToNetwork = BIO_new(BIO_s_mem()); + + if (!this->sslBioToNetwork) + { + LOG_OPENSSL_ERROR("BIO_new() failed"); + + BIO_free(this->sslBioFromNetwork); + SSL_free(this->ssl); + + goto error; + } + + SSL_set_bio(this->ssl, this->sslBioFromNetwork, this->sslBioToNetwork); + + // Set the MTU so that we don't send packets that are too large with no fragmentation. + SSL_set_mtu(this->ssl, DtlsMtu); + DTLS_set_link_mtu(this->ssl, DtlsMtu); + + // Set callback handler for setting DTLS timer interval. + DTLS_set_timer_cb(this->ssl, onSslDtlsTimer); + + return; + + error: + + // NOTE: At this point SSL_set_bio() was not called so we must free BIOs as + // well. + if (this->sslBioFromNetwork) + BIO_free(this->sslBioFromNetwork); + + if (this->sslBioToNetwork) + BIO_free(this->sslBioToNetwork); + + if (this->ssl) + SSL_free(this->ssl); + + // NOTE: If this is not catched by the caller the program will abort, but + // this should never happen. + MS_THROW_ERROR("DtlsTransport instance creation failed"); + } + + DtlsTransport::~DtlsTransport() + { + MS_TRACE(); + + if (IsRunning()) + { + // Send close alert to the peer. + SSL_shutdown(this->ssl); + SendPendingOutgoingDtlsData(); + } + + if (this->ssl) + { + SSL_free(this->ssl); + + this->ssl = nullptr; + this->sslBioFromNetwork = nullptr; + this->sslBioToNetwork = nullptr; + } + + // Close the DTLS timer. + this->timer = nullptr; + } + + void DtlsTransport::Dump() const + { + MS_TRACE(); + + std::string state{ "new" }; + std::string role{ "none " }; + + switch (this->state) + { + case DtlsState::CONNECTING: + state = "connecting"; + break; + case DtlsState::CONNECTED: + state = "connected"; + break; + case DtlsState::FAILED: + state = "failed"; + break; + case DtlsState::CLOSED: + state = "closed"; + break; + default:; + } + + switch (this->localRole) + { + case Role::AUTO: + role = "auto"; + break; + case Role::SERVER: + role = "server"; + break; + case Role::CLIENT: + role = "client"; + break; + default:; + } + + MS_DUMP(""); + MS_DUMP(" state : %s", state.c_str()); + MS_DUMP(" role : %s", role.c_str()); + MS_DUMP(" handshake done: : %s", this->handshakeDone ? "yes" : "no"); + MS_DUMP(""); + } + + void DtlsTransport::Run(Role localRole) + { + DebugL << ((localRole == RTC::DtlsTransport::Role::SERVER)? "Server" : "Client"); + + MS_TRACE(); + + MS_ASSERT( + localRole == Role::CLIENT || localRole == Role::SERVER, + "local DTLS role must be 'client' or 'server'"); + + Role previousLocalRole = this->localRole; + + if (localRole == previousLocalRole) + { + MS_ERROR("same local DTLS role provided, doing nothing"); + + return; + } + + // If the previous local DTLS role was 'client' or 'server' do reset. + if (previousLocalRole == Role::CLIENT || previousLocalRole == Role::SERVER) + { + MS_DEBUG_TAG(dtls, "resetting DTLS due to local role change"); + + Reset(); + } + + // Update local role. + this->localRole = localRole; + + // Set state and notify the listener. + this->state = DtlsState::CONNECTING; + this->listener->OnDtlsTransportConnecting(this); + + switch (this->localRole) + { + case Role::CLIENT: + { + MS_DEBUG_TAG(dtls, "running [role:client]"); + + SSL_set_connect_state(this->ssl); + SSL_do_handshake(this->ssl); + SendPendingOutgoingDtlsData(); + SetTimeout(); + + break; + } + + case Role::SERVER: + { + MS_DEBUG_TAG(dtls, "running [role:server]"); + + SSL_set_accept_state(this->ssl); + SSL_do_handshake(this->ssl); + + break; + } + + default: + { + MS_ABORT("invalid local DTLS role"); + } + } + } + + bool DtlsTransport::SetRemoteFingerprint(Fingerprint fingerprint) + { + MS_TRACE(); + + MS_ASSERT( + fingerprint.algorithm != FingerprintAlgorithm::NONE, "no fingerprint algorithm provided"); + + this->remoteFingerprint = fingerprint; + + // The remote fingerpring may have been set after DTLS handshake was done, + // so we may need to process it now. + if (this->handshakeDone && this->state != DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "handshake already done, processing it right now"); + + return ProcessHandshake(); + } + + return true; + } + + void DtlsTransport::ProcessDtlsData(const uint8_t* data, size_t len) + { + MS_TRACE(); + + int written; + int read; + + if (!IsRunning()) + { + MS_WARN_TAG(nullptr,"cannot process data while not running"); + return; + } + + // Write the received DTLS data into the sslBioFromNetwork. + written = + BIO_write(this->sslBioFromNetwork, static_cast(data), static_cast(len)); + + if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, + "OpenSSL BIO_write() wrote less (%zu bytes) than given data (%zu bytes)", + static_cast(written), + len); + } + + // Must call SSL_read() to process received DTLS data. + read = SSL_read(this->ssl, static_cast(DtlsTransport::sslReadBuffer), SslReadBufferSize); + + // Send data if it's ready. + SendPendingOutgoingDtlsData(); + + // Check SSL status and return if it is bad/closed. + if (!CheckStatus(read)) + return; + + // Set/update the DTLS timeout. + if (!SetTimeout()) + return; + + // Application data received. Notify to the listener. + if (read > 0) + { + // It is allowed to receive DTLS data even before validating remote fingerprint. + if (!this->handshakeDone) + { + MS_WARN_TAG(dtls, "ignoring application data received while DTLS handshake not done"); + + return; + } + + // Notify the listener. + this->listener->OnDtlsTransportApplicationDataReceived( + this, (uint8_t*)DtlsTransport::sslReadBuffer, static_cast(read)); + } + } + + void DtlsTransport::SendApplicationData(const uint8_t* data, size_t len) + { + MS_TRACE(); + + // We cannot send data to the peer if its remote fingerprint is not validated. + if (this->state != DtlsState::CONNECTED) + { + MS_WARN_TAG(dtls, "cannot send application data while DTLS is not fully connected"); + + return; + } + + if (len == 0) + { + MS_WARN_TAG(dtls, "ignoring 0 length data"); + + return; + } + + int written; + + written = SSL_write(this->ssl, static_cast(data), static_cast(len)); + + if (written < 0) + { + LOG_OPENSSL_ERROR("SSL_write() failed"); + + if (!CheckStatus(written)) + return; + } + else if (written != static_cast(len)) + { + MS_WARN_TAG( + dtls, "OpenSSL SSL_write() wrote less (%d bytes) than given data (%zu bytes)", written, len); + } + + // Send data. + SendPendingOutgoingDtlsData(); + } + + void DtlsTransport::Reset() + { + MS_TRACE(); + + int ret; + + if (!IsRunning()) + return; + + MS_WARN_TAG(dtls, "resetting DTLS transport"); + + // Stop the DTLS timer. + this->timer = nullptr; + + // We need to reset the SSL instance so we need to "shutdown" it, but we + // don't want to send a Close Alert to the peer, so just don't call + // SendPendingOutgoingDTLSData(). + SSL_shutdown(this->ssl); + + this->localRole = Role::NONE; + this->state = DtlsState::NEW; + this->handshakeDone = false; + this->handshakeDoneNow = false; + + // Reset SSL status. + // NOTE: For this to properly work, SSL_shutdown() must be called before. + // NOTE: This may fail if not enough DTLS handshake data has been received, + // but we don't care so just clear the error queue. + ret = SSL_clear(this->ssl); + + if (ret == 0) + ERR_clear_error(); + } + + inline bool DtlsTransport::CheckStatus(int returnCode) + { + MS_TRACE(); + + int err; + bool wasHandshakeDone = this->handshakeDone; + + err = SSL_get_error(this->ssl, returnCode); + + switch (err) + { + case SSL_ERROR_NONE: + break; + + case SSL_ERROR_SSL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SSL"); + break; + + case SSL_ERROR_WANT_READ: + break; + + case SSL_ERROR_WANT_WRITE: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_WRITE"); + break; + + case SSL_ERROR_WANT_X509_LOOKUP: + MS_DEBUG_TAG(dtls, "SSL status: SSL_ERROR_WANT_X509_LOOKUP"); + break; + + case SSL_ERROR_SYSCALL: + LOG_OPENSSL_ERROR("SSL status: SSL_ERROR_SYSCALL"); + break; + + case SSL_ERROR_ZERO_RETURN: + break; + + case SSL_ERROR_WANT_CONNECT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_CONNECT"); + break; + + case SSL_ERROR_WANT_ACCEPT: + MS_WARN_TAG(dtls, "SSL status: SSL_ERROR_WANT_ACCEPT"); + break; + + default: + MS_WARN_TAG(dtls, "SSL status: unknown error"); + } + + // Check if the handshake (or re-handshake) has been done right now. + if (this->handshakeDoneNow) + { + this->handshakeDoneNow = false; + this->handshakeDone = true; + + // Stop the timer. + this->timer = nullptr; + + // Process the handshake just once (ignore if DTLS renegotiation). + if (!wasHandshakeDone && this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE) + return ProcessHandshake(); + + return true; + } + // Check if the peer sent close alert or a fatal error happened. + else if (((SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN) != 0) || err == SSL_ERROR_SSL || err == SSL_ERROR_SYSCALL) + { + if (this->state == DtlsState::CONNECTED) + { + MS_DEBUG_TAG(dtls, "disconnected"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::CLOSED; + this->listener->OnDtlsTransportClosed(this); + } + else + { + MS_WARN_TAG(dtls, "connection failed"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + } + + return false; + } + else + { + return true; + } + } + + inline void DtlsTransport::SendPendingOutgoingDtlsData() + { + MS_TRACE(); + + if (BIO_eof(this->sslBioToNetwork)) + return; + + int64_t read; + char* data{ nullptr }; + + read = BIO_get_mem_data(this->sslBioToNetwork, &data); // NOLINT + + if (read <= 0) + return; + + MS_DEBUG_DEV("%" PRIu64 " bytes of DTLS data ready to sent to the peer", read); + + // Notify the listener. + this->listener->OnDtlsTransportSendData( + this, reinterpret_cast(data), static_cast(read)); + + // Clear the BIO buffer. + // NOTE: the (void) avoids the -Wunused-value warning. + (void)BIO_reset(this->sslBioToNetwork); + } + + inline bool DtlsTransport::SetTimeout() + { + MS_TRACE(); + + MS_ASSERT( + this->state == DtlsState::CONNECTING || this->state == DtlsState::CONNECTED, + "invalid DTLS state"); + + int64_t ret; + struct timeval dtlsTimeout{ 0, 0 }; + uint64_t timeoutMs; + + // NOTE: If ret == 0 then ignore the value in dtlsTimeout. + // NOTE: No DTLSv_1_2_get_timeout() or DTLS_get_timeout() in OpenSSL 1.1.0-dev. + ret = DTLSv1_get_timeout(this->ssl, static_cast(&dtlsTimeout)); // NOLINT + + if (ret == 0) + return true; + + timeoutMs = (dtlsTimeout.tv_sec * static_cast(1000)) + (dtlsTimeout.tv_usec / 1000); + + if (timeoutMs == 0) + { + return true; + } + else if (timeoutMs < 30000) + { + MS_DEBUG_DEV("DTLS timer set in %" PRIu64 "ms", timeoutMs); + + weak_ptr weak_self = shared_from_this(); + this->timer = std::make_shared(timeoutMs / 1000.0f, [weak_self](){ + auto strong_self = weak_self.lock(); + if(strong_self){ + strong_self->OnTimer(); + } + return true; + }, this->poller); + + return true; + } + // NOTE: Don't start the timer again if the timeout is greater than 30 seconds. + else + { + MS_WARN_TAG(dtls, "DTLS timeout too high (%" PRIu64 "ms), resetting DLTS", timeoutMs); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } + } + + inline bool DtlsTransport::ProcessHandshake() + { + MS_TRACE(); + + MS_ASSERT(this->handshakeDone, "handshake not done yet"); + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + + // Validate the remote fingerprint. + if (!CheckRemoteFingerprint()) + { + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } + + // Get the negotiated SRTP crypto suite. + RTC::SrtpSession::CryptoSuite srtpCryptoSuite = GetNegotiatedSrtpCryptoSuite(); + + if (srtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE) + { + // Extract the SRTP keys (will notify the listener with them). + ExtractSrtpKeys(srtpCryptoSuite); + + return true; + } + + // NOTE: We assume that "use_srtp" DTLS extension is required even if + // there is no audio/video. + MS_WARN_2TAGS(dtls, srtp, "SRTP crypto suite not negotiated"); + + Reset(); + + // Set state and notify the listener. + this->state = DtlsState::FAILED; + this->listener->OnDtlsTransportFailed(this); + + return false; + } + + inline bool DtlsTransport::CheckRemoteFingerprint() + { + MS_TRACE(); + + MS_ASSERT( + this->remoteFingerprint.algorithm != FingerprintAlgorithm::NONE, "remote fingerprint not set"); + + X509* certificate; + uint8_t binaryFingerprint[EVP_MAX_MD_SIZE]; + unsigned int size{ 0 }; + char hexFingerprint[(EVP_MAX_MD_SIZE * 3) + 1]; + const EVP_MD* hashFunction; + int ret; + + certificate = SSL_get_peer_certificate(this->ssl); + + if (!certificate) + { + MS_WARN_TAG(dtls, "no certificate was provided by the peer"); + + return false; + } + + switch (this->remoteFingerprint.algorithm) + { + case FingerprintAlgorithm::SHA1: + hashFunction = EVP_sha1(); + break; + + case FingerprintAlgorithm::SHA224: + hashFunction = EVP_sha224(); + break; + + case FingerprintAlgorithm::SHA256: + hashFunction = EVP_sha256(); + break; + + case FingerprintAlgorithm::SHA384: + hashFunction = EVP_sha384(); + break; + + case FingerprintAlgorithm::SHA512: + hashFunction = EVP_sha512(); + break; + + default: + MS_ABORT("unknown algorithm"); + } + + // Compare the remote fingerprint with the value given via signaling. + ret = X509_digest(certificate, hashFunction, binaryFingerprint, &size); + + if (ret == 0) + { + MS_ERROR("X509_digest() failed"); + + X509_free(certificate); + + return false; + } + + // Convert to hexadecimal format in uppercase with colons. + for (unsigned int i{ 0 }; i < size; ++i) + { + std::sprintf(hexFingerprint + (i * 3), "%.2X:", binaryFingerprint[i]); + } + hexFingerprint[(size * 3) - 1] = '\0'; + + if (this->remoteFingerprint.value != hexFingerprint) + { + MS_WARN_TAG( + dtls, + "fingerprint in the remote certificate (%s) does not match the announced one (%s)", + hexFingerprint, + this->remoteFingerprint.value.c_str()); + X509_free(certificate); + return false; + } + + MS_DEBUG_TAG(dtls, "valid remote fingerprint"); + + // Get the remote certificate in PEM format. + + BIO* bio = BIO_new(BIO_s_mem()); + + // Ensure the underlying BUF_MEM structure is also freed. + // NOTE: Avoid stupid "warning: value computed is not used [-Wunused-value]" since + // BIO_set_close() always returns 1. + (void)BIO_set_close(bio, BIO_CLOSE); + + ret = PEM_write_bio_X509(bio, certificate); + + if (ret != 1) + { + LOG_OPENSSL_ERROR("PEM_write_bio_X509() failed"); + + X509_free(certificate); + BIO_free(bio); + + return false; + } + + BUF_MEM* mem; + + BIO_get_mem_ptr(bio, &mem); // NOLINT[cppcoreguidelines-pro-type-cstyle-cast] + + if (!mem || !mem->data || mem->length == 0u) + { + LOG_OPENSSL_ERROR("BIO_get_mem_ptr() failed"); + + X509_free(certificate); + BIO_free(bio); + + return false; + } + + this->remoteCert = std::string(mem->data, mem->length); + + X509_free(certificate); + BIO_free(bio); + + return true; + } + + inline void DtlsTransport::ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite) + { + MS_TRACE(); + + size_t srtpKeyLength{ 0 }; + size_t srtpSaltLength{ 0 }; + size_t srtpMasterLength{ 0 }; + + switch (srtpCryptoSuite) + { + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_80: + case RTC::SrtpSession::CryptoSuite::AES_CM_128_HMAC_SHA1_32: + { + srtpKeyLength = SrtpMasterKeyLength; + srtpSaltLength = SrtpMasterSaltLength; + srtpMasterLength = SrtpMasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_256_GCM: + { + srtpKeyLength = SrtpAesGcm256MasterKeyLength; + srtpSaltLength = SrtpAesGcm256MasterSaltLength; + srtpMasterLength = SrtpAesGcm256MasterLength; + + break; + } + + case RTC::SrtpSession::CryptoSuite::AEAD_AES_128_GCM: + { + srtpKeyLength = SrtpAesGcm128MasterKeyLength; + srtpSaltLength = SrtpAesGcm128MasterSaltLength; + srtpMasterLength = SrtpAesGcm128MasterLength; + + break; + } + + default: + { + MS_ABORT("unknown SRTP crypto suite"); + } + } + + auto* srtpMaterial = new uint8_t[srtpMasterLength * 2]; + uint8_t* srtpLocalKey{ nullptr }; + uint8_t* srtpLocalSalt{ nullptr }; + uint8_t* srtpRemoteKey{ nullptr }; + uint8_t* srtpRemoteSalt{ nullptr }; + auto* srtpLocalMasterKey = new uint8_t[srtpMasterLength]; + auto* srtpRemoteMasterKey = new uint8_t[srtpMasterLength]; + int ret; + + ret = SSL_export_keying_material( + this->ssl, srtpMaterial, srtpMasterLength * 2, "EXTRACTOR-dtls_srtp", 19, nullptr, 0, 0); + + MS_ASSERT(ret != 0, "SSL_export_keying_material() failed"); + + switch (this->localRole) + { + case Role::SERVER: + { + srtpRemoteKey = srtpMaterial; + srtpLocalKey = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteSalt + srtpSaltLength; + + break; + } + + case Role::CLIENT: + { + srtpLocalKey = srtpMaterial; + srtpRemoteKey = srtpLocalKey + srtpKeyLength; + srtpLocalSalt = srtpRemoteKey + srtpKeyLength; + srtpRemoteSalt = srtpLocalSalt + srtpSaltLength; + + break; + } + + default: + { + MS_ABORT("no DTLS role set"); + } + } + + // Create the SRTP local master key. + std::memcpy(srtpLocalMasterKey, srtpLocalKey, srtpKeyLength); + std::memcpy(srtpLocalMasterKey + srtpKeyLength, srtpLocalSalt, srtpSaltLength); + // Create the SRTP remote master key. + std::memcpy(srtpRemoteMasterKey, srtpRemoteKey, srtpKeyLength); + std::memcpy(srtpRemoteMasterKey + srtpKeyLength, srtpRemoteSalt, srtpSaltLength); + + // Set state and notify the listener. + this->state = DtlsState::CONNECTED; + this->listener->OnDtlsTransportConnected( + this, + srtpCryptoSuite, + srtpLocalMasterKey, + srtpMasterLength, + srtpRemoteMasterKey, + srtpMasterLength, + this->remoteCert); + + delete[] srtpMaterial; + delete[] srtpLocalMasterKey; + delete[] srtpRemoteMasterKey; + } + + inline RTC::SrtpSession::CryptoSuite DtlsTransport::GetNegotiatedSrtpCryptoSuite() + { + MS_TRACE(); + + RTC::SrtpSession::CryptoSuite negotiatedSrtpCryptoSuite = RTC::SrtpSession::CryptoSuite::NONE; + + // Ensure that the SRTP crypto suite has been negotiated. + // NOTE: This is a OpenSSL type. + SRTP_PROTECTION_PROFILE* sslSrtpCryptoSuite = SSL_get_selected_srtp_profile(this->ssl); + + if (!sslSrtpCryptoSuite) + return negotiatedSrtpCryptoSuite; + + // Get the negotiated SRTP crypto suite. + for (auto& srtpCryptoSuite : DtlsTransport::srtpCryptoSuites) + { + SrtpCryptoSuiteMapEntry* cryptoSuiteEntry = std::addressof(srtpCryptoSuite); + + if (std::strcmp(sslSrtpCryptoSuite->name, cryptoSuiteEntry->name) == 0) + { + MS_DEBUG_2TAGS(dtls, srtp, "chosen SRTP crypto suite: %s", cryptoSuiteEntry->name); + + negotiatedSrtpCryptoSuite = cryptoSuiteEntry->cryptoSuite; + } + } + + MS_ASSERT( + negotiatedSrtpCryptoSuite != RTC::SrtpSession::CryptoSuite::NONE, + "chosen SRTP crypto suite is not an available one"); + + return negotiatedSrtpCryptoSuite; + } + + inline void DtlsTransport::OnSslInfo(int where, int ret) + { + MS_TRACE(); + + int w = where & -SSL_ST_MASK; + const char* role; + + if ((w & SSL_ST_CONNECT) != 0) + role = "client"; + else if ((w & SSL_ST_ACCEPT) != 0) + role = "server"; + else + role = "undefined"; + + if ((where & SSL_CB_LOOP) != 0) + { + MS_DEBUG_TAG(dtls, "[role:%s, action:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_ALERT) != 0) + { + const char* alertType; + + switch (*SSL_alert_type_string(ret)) + { + case 'W': + alertType = "warning"; + break; + + case 'F': + alertType = "fatal"; + break; + + default: + alertType = "undefined"; + } + + if ((where & SSL_CB_READ) != 0) + { + MS_WARN_TAG(dtls, "received DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else if ((where & SSL_CB_WRITE) != 0) + { + MS_DEBUG_TAG(dtls, "sending DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + else + { + MS_DEBUG_TAG(dtls, "DTLS %s alert: %s", alertType, SSL_alert_desc_string_long(ret)); + } + } + else if ((where & SSL_CB_EXIT) != 0) + { + if (ret == 0) + MS_DEBUG_TAG(dtls, "[role:%s, failed:'%s']", role, SSL_state_string_long(this->ssl)); + else if (ret < 0) + MS_DEBUG_TAG(dtls, "role: %s, waiting:'%s']", role, SSL_state_string_long(this->ssl)); + } + else if ((where & SSL_CB_HANDSHAKE_START) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake start"); + } + else if ((where & SSL_CB_HANDSHAKE_DONE) != 0) + { + MS_DEBUG_TAG(dtls, "DTLS handshake done"); + + this->handshakeDoneNow = true; + } + + // NOTE: checking SSL_get_shutdown(this->ssl) & SSL_RECEIVED_SHUTDOWN here upon + // receipt of a close alert does not work (the flag is set after this callback). + } + + inline void DtlsTransport::OnTimer() + { + MS_TRACE(); + + // Workaround for https://github.com/openssl/openssl/issues/7998. + if (this->handshakeDone) + { + // MS_DEBUG_DEV("handshake is done so return"); + return; + } + + DTLSv1_handle_timeout(this->ssl); + + // If required, send DTLS data. + SendPendingOutgoingDtlsData(); + + // Set the DTLS timer again. + SetTimeout(); + } +} // namespace RTC diff --git a/webrtc/DtlsTransport.hpp b/webrtc/DtlsTransport.hpp index 53a1981d..48198358 100644 --- a/webrtc/DtlsTransport.hpp +++ b/webrtc/DtlsTransport.hpp @@ -1,254 +1,254 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef MS_RTC_DTLS_TRANSPORT_HPP -#define MS_RTC_DTLS_TRANSPORT_HPP - -#include "SrtpSession.hpp" -#include -#include -#include -#include -#include -#include -#include "Poller/Timer.h" -#include "Poller/EventPoller.h" -using namespace toolkit; - -namespace RTC -{ - class DtlsTransport : public std::enable_shared_from_this - { - public: - enum class DtlsState - { - NEW = 1, - CONNECTING, - CONNECTED, - FAILED, - CLOSED - }; - - public: - enum class Role - { - NONE = 0, - AUTO = 1, - CLIENT, - SERVER - }; - - public: - enum class FingerprintAlgorithm - { - NONE = 0, - SHA1 = 1, - SHA224, - SHA256, - SHA384, - SHA512 - }; - - public: - struct Fingerprint - { - FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; - std::string value; - }; - - private: - struct SrtpCryptoSuiteMapEntry - { - RTC::SrtpSession::CryptoSuite cryptoSuite; - const char* name; - }; - - class DtlsEnvironment : public std::enable_shared_from_this - { - public: - using Ptr = std::shared_ptr; - ~DtlsEnvironment(); - static DtlsEnvironment& Instance(); - - private: - DtlsEnvironment(); - void GenerateCertificateAndPrivateKey(); - bool ReadCertificateAndPrivateKeyFromContext(SSL_CTX *ctx); - void CreateSslCtx(); - void GenerateFingerprints(); - - public: - X509* certificate{ nullptr }; - EVP_PKEY* privateKey{ nullptr }; - SSL_CTX* sslCtx{ nullptr }; - std::vector localFingerprints; - }; - - public: - class Listener - { - public: - // DTLS is in the process of negotiating a secure connection. Incoming - // media can flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; - // DTLS has completed negotiation of a secure connection (including DTLS-SRTP - // and remote fingerprint verification). Outgoing media can now flow through. - // NOTE: The caller MUST NOT call any method during this callback. - virtual void OnDtlsTransportConnected( - const RTC::DtlsTransport* dtlsTransport, - RTC::SrtpSession::CryptoSuite srtpCryptoSuite, - uint8_t* srtpLocalKey, - size_t srtpLocalKeyLen, - uint8_t* srtpRemoteKey, - size_t srtpRemoteKeyLen, - std::string& remoteCert) = 0; - // The DTLS connection has been closed as the result of an error (such as a - // DTLS alert or a failure to validate the remote fingerprint). - virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; - // The DTLS connection has been closed due to receipt of a close_notify alert. - virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; - // Need to send DTLS data to the peer. - virtual void OnDtlsTransportSendData( - const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; - // DTLS application data received. - virtual void OnDtlsTransportApplicationDataReceived( - const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; - }; - - public: - static Role StringToRole(const std::string& role) - { - auto it = DtlsTransport::string2Role.find(role); - - if (it != DtlsTransport::string2Role.end()) - return it->second; - else - return DtlsTransport::Role::NONE; - } - static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) - { - auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); - - if (it != DtlsTransport::string2FingerprintAlgorithm.end()) - return it->second; - else - return DtlsTransport::FingerprintAlgorithm::NONE; - } - static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) - { - auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); - - return it->second; - } - static bool IsDtls(const uint8_t* data, size_t len) - { - // clang-format off - return ( - // Minimum DTLS record length is 13 bytes. - (len >= 13) && - // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes - (data[0] > 19 && data[0] < 64) - ); - // clang-format on - } - - private: - static std::map string2Role; - static std::map string2FingerprintAlgorithm; - static std::map fingerprintAlgorithm2String; - static std::vector srtpCryptoSuites; - - public: - DtlsTransport(EventPoller::Ptr poller, Listener* listener); - ~DtlsTransport(); - - public: - void Dump() const; - void Run(Role localRole); - std::vector& GetLocalFingerprints() const - { - return env->localFingerprints; - } - bool SetRemoteFingerprint(Fingerprint fingerprint); - void ProcessDtlsData(const uint8_t* data, size_t len); - DtlsState GetState() const - { - return this->state; - } - Role GetLocalRole() const - { - return this->localRole; - } - void SendApplicationData(const uint8_t* data, size_t len); - - private: - bool IsRunning() const - { - switch (this->state) - { - case DtlsState::NEW: - return false; - case DtlsState::CONNECTING: - case DtlsState::CONNECTED: - return true; - case DtlsState::FAILED: - case DtlsState::CLOSED: - return false; - } - - // Make GCC 4.9 happy. - return false; - } - void Reset(); - bool CheckStatus(int returnCode); - void SendPendingOutgoingDtlsData(); - bool SetTimeout(); - bool ProcessHandshake(); - bool CheckRemoteFingerprint(); - void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); - RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); - - private: - void OnSslInfo(int where, int ret); - void OnTimer(); - - private: - DtlsEnvironment::Ptr env; - EventPoller::Ptr poller; - // Passed by argument. - Listener* listener{ nullptr }; - // Allocated by this. - SSL* ssl{ nullptr }; - BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. - BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. - Timer::Ptr timer; - // Others. - DtlsState state{ DtlsState::NEW }; - Role localRole{ Role::NONE }; - Fingerprint remoteFingerprint; - bool handshakeDone{ false }; - bool handshakeDoneNow{ false }; - std::string remoteCert; - //最大不超过mtu - static constexpr int SslReadBufferSize{ 2000 }; - uint8_t sslReadBuffer[SslReadBufferSize]; -}; -} // namespace RTC - -#endif +/** +ISC License + +Copyright © 2015, Iñaki Baz Castillo + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef MS_RTC_DTLS_TRANSPORT_HPP +#define MS_RTC_DTLS_TRANSPORT_HPP + +#include "SrtpSession.hpp" +#include +#include +#include +#include +#include +#include +#include "Poller/Timer.h" +#include "Poller/EventPoller.h" + +namespace RTC +{ + class DtlsTransport : public std::enable_shared_from_this + { + public: + using Ptr = std::shared_ptr; + enum class DtlsState + { + NEW = 1, + CONNECTING, + CONNECTED, + FAILED, + CLOSED + }; + + public: + enum class Role + { + NONE = 0, + AUTO = 1, + CLIENT, + SERVER + }; + + public: + enum class FingerprintAlgorithm + { + NONE = 0, + SHA1 = 1, + SHA224, + SHA256, + SHA384, + SHA512 + }; + + public: + struct Fingerprint + { + FingerprintAlgorithm algorithm{ FingerprintAlgorithm::NONE }; + std::string value; + }; + + private: + struct SrtpCryptoSuiteMapEntry + { + RTC::SrtpSession::CryptoSuite cryptoSuite; + const char* name; + }; + + class DtlsEnvironment : public std::enable_shared_from_this + { + public: + using Ptr = std::shared_ptr; + ~DtlsEnvironment(); + static DtlsEnvironment& Instance(); + + private: + DtlsEnvironment(); + void GenerateCertificateAndPrivateKey(); + bool ReadCertificateAndPrivateKeyFromContext(SSL_CTX *ctx); + void CreateSslCtx(); + void GenerateFingerprints(); + + public: + X509* certificate{ nullptr }; + EVP_PKEY* privateKey{ nullptr }; + SSL_CTX* sslCtx{ nullptr }; + std::vector localFingerprints; + }; + + public: + class Listener + { + public: + // DTLS is in the process of negotiating a secure connection. Incoming + // media can flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnecting(const RTC::DtlsTransport* dtlsTransport) = 0; + // DTLS has completed negotiation of a secure connection (including DTLS-SRTP + // and remote fingerprint verification). Outgoing media can now flow through. + // NOTE: The caller MUST NOT call any method during this callback. + virtual void OnDtlsTransportConnected( + const RTC::DtlsTransport* dtlsTransport, + RTC::SrtpSession::CryptoSuite srtpCryptoSuite, + uint8_t* srtpLocalKey, + size_t srtpLocalKeyLen, + uint8_t* srtpRemoteKey, + size_t srtpRemoteKeyLen, + std::string& remoteCert) = 0; + // The DTLS connection has been closed as the result of an error (such as a + // DTLS alert or a failure to validate the remote fingerprint). + virtual void OnDtlsTransportFailed(const RTC::DtlsTransport* dtlsTransport) = 0; + // The DTLS connection has been closed due to receipt of a close_notify alert. + virtual void OnDtlsTransportClosed(const RTC::DtlsTransport* dtlsTransport) = 0; + // Need to send DTLS data to the peer. + virtual void OnDtlsTransportSendData( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + // DTLS application data received. + virtual void OnDtlsTransportApplicationDataReceived( + const RTC::DtlsTransport* dtlsTransport, const uint8_t* data, size_t len) = 0; + }; + + public: + static Role StringToRole(const std::string& role) + { + auto it = DtlsTransport::string2Role.find(role); + + if (it != DtlsTransport::string2Role.end()) + return it->second; + else + return DtlsTransport::Role::NONE; + } + static FingerprintAlgorithm GetFingerprintAlgorithm(const std::string& fingerprint) + { + auto it = DtlsTransport::string2FingerprintAlgorithm.find(fingerprint); + + if (it != DtlsTransport::string2FingerprintAlgorithm.end()) + return it->second; + else + return DtlsTransport::FingerprintAlgorithm::NONE; + } + static std::string& GetFingerprintAlgorithmString(FingerprintAlgorithm fingerprint) + { + auto it = DtlsTransport::fingerprintAlgorithm2String.find(fingerprint); + + return it->second; + } + static bool IsDtls(const uint8_t* data, size_t len) + { + // clang-format off + return ( + // Minimum DTLS record length is 13 bytes. + (len >= 13) && + // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes + (data[0] > 19 && data[0] < 64) + ); + // clang-format on + } + + private: + static std::map string2Role; + static std::map string2FingerprintAlgorithm; + static std::map fingerprintAlgorithm2String; + static std::vector srtpCryptoSuites; + + public: + DtlsTransport(toolkit::EventPoller::Ptr poller, Listener* listener); + ~DtlsTransport(); + + public: + void Dump() const; + void Run(Role localRole); + std::vector& GetLocalFingerprints() const + { + return env->localFingerprints; + } + bool SetRemoteFingerprint(Fingerprint fingerprint); + void ProcessDtlsData(const uint8_t* data, size_t len); + DtlsState GetState() const + { + return this->state; + } + Role GetLocalRole() const + { + return this->localRole; + } + void SendApplicationData(const uint8_t* data, size_t len); + + private: + bool IsRunning() const + { + switch (this->state) + { + case DtlsState::NEW: + return false; + case DtlsState::CONNECTING: + case DtlsState::CONNECTED: + return true; + case DtlsState::FAILED: + case DtlsState::CLOSED: + return false; + } + + // Make GCC 4.9 happy. + return false; + } + void Reset(); + bool CheckStatus(int returnCode); + void SendPendingOutgoingDtlsData(); + bool SetTimeout(); + bool ProcessHandshake(); + bool CheckRemoteFingerprint(); + void ExtractSrtpKeys(RTC::SrtpSession::CryptoSuite srtpCryptoSuite); + RTC::SrtpSession::CryptoSuite GetNegotiatedSrtpCryptoSuite(); + + private: + void OnSslInfo(int where, int ret); + void OnTimer(); + + private: + DtlsEnvironment::Ptr env; + toolkit::EventPoller::Ptr poller; + // Passed by argument. + Listener* listener{ nullptr }; + // Allocated by this. + SSL* ssl{ nullptr }; + BIO* sslBioFromNetwork{ nullptr }; // The BIO from which ssl reads. + BIO* sslBioToNetwork{ nullptr }; // The BIO in which ssl writes. + toolkit::Timer::Ptr timer; + // Others. + DtlsState state{ DtlsState::NEW }; + Role localRole{ Role::NONE }; + Fingerprint remoteFingerprint; + bool handshakeDone{ false }; + bool handshakeDoneNow{ false }; + std::string remoteCert; + //最大不超过mtu + static constexpr int SslReadBufferSize{ 2000 }; + uint8_t sslReadBuffer[SslReadBufferSize]; +}; +} // namespace RTC + +#endif diff --git a/webrtc/IceServer.cpp b/webrtc/IceServer.cpp deleted file mode 100644 index d4530d10..00000000 --- a/webrtc/IceServer.cpp +++ /dev/null @@ -1,528 +0,0 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#define MS_CLASS "RTC::IceServer" -// #define MS_LOG_DEV_LEVEL 3 - -#include -#include "IceServer.hpp" - -namespace RTC -{ - /* Static. */ - /* Instance methods. */ - - IceServer::IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password) - : listener(listener), usernameFragment(usernameFragment), password(password) - { - MS_TRACE(); - } - - void IceServer::ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple) - { - MS_TRACE(); - - // Must be a Binding method. - if (packet->GetMethod() != RTC::StunPacket::Method::BINDING) - { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) - { - MS_WARN_TAG( - ice, - "unknown method %#.3x in STUN Request => 400", - static_cast(packet->GetMethod())); - - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); - - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - } - else - { - MS_WARN_TAG( - ice, - "ignoring STUN Indication or Response with unknown method %#.3x", - static_cast(packet->GetMethod())); - } - - return; - } - - // Must use FINGERPRINT (optional for ICE STUN indications). - if (!packet->HasFingerprint() && packet->GetClass() != RTC::StunPacket::Class::INDICATION) - { - if (packet->GetClass() == RTC::StunPacket::Class::REQUEST) - { - MS_WARN_TAG(ice, "STUN Binding Request without FINGERPRINT => 400"); - - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); - - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - } - else - { - MS_WARN_TAG(ice, "ignoring STUN Binding Response without FINGERPRINT"); - } - - return; - } - - switch (packet->GetClass()) - { - case RTC::StunPacket::Class::REQUEST: - { - // USERNAME, MESSAGE-INTEGRITY and PRIORITY are required. - if (!packet->HasMessageIntegrity() || (packet->GetPriority() == 0u) || packet->GetUsername().empty()) - { - MS_WARN_TAG(ice, "mising required attributes in STUN Binding Request => 400"); - - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); - - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - - return; - } - - // Check authentication. - switch (packet->CheckAuthentication(this->usernameFragment, this->password)) - { - case RTC::StunPacket::Authentication::OK: - { - if (!this->oldPassword.empty()) - { - MS_DEBUG_TAG(ice, "new ICE credentials applied"); - - this->oldUsernameFragment.clear(); - this->oldPassword.clear(); - } - - break; - } - - case RTC::StunPacket::Authentication::UNAUTHORIZED: - { - // We may have changed our usernameFragment and password, so check - // the old ones. - // clang-format off - if ( - !this->oldUsernameFragment.empty() && - !this->oldPassword.empty() && - packet->CheckAuthentication(this->oldUsernameFragment, this->oldPassword) == RTC::StunPacket::Authentication::OK - ) - // clang-format on - { - MS_DEBUG_TAG(ice, "using old ICE credentials"); - - break; - } - - MS_WARN_TAG(ice, "wrong authentication in STUN Binding Request => 401"); - - // Reply 401. - RTC::StunPacket* response = packet->CreateErrorResponse(401); - - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - - return; - } - - case RTC::StunPacket::Authentication::BAD_REQUEST: - { - MS_WARN_TAG(ice, "cannot check authentication in STUN Binding Request => 400"); - - // Reply 400. - RTC::StunPacket* response = packet->CreateErrorResponse(400); - - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - - return; - } - } - -#if 0 - // The remote peer must be ICE controlling. - if (packet->GetIceControlled()) - { - MS_WARN_TAG(ice, "peer indicates ICE-CONTROLLED in STUN Binding Request => 487"); - - // Reply 487 (Role Conflict). - RTC::StunPacket* response = packet->CreateErrorResponse(487); - - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - - return; - } - -#endif - - //MS_DEBUG_DEV( - // "processing STUN Binding Request [Priority:%" PRIu32 ", UseCandidate:%s]", - // static_cast(packet->GetPriority()), - // packet->HasUseCandidate() ? "true" : "false"); - - // Create a success response. - RTC::StunPacket* response = packet->CreateSuccessResponse(); - - sockaddr_storage peerAddr; - socklen_t addr_len = sizeof(peerAddr); - getpeername(tuple->getSock()->rawFD(), (struct sockaddr *)&peerAddr, &addr_len); - - // Add XOR-MAPPED-ADDRESS. - response->SetXorMappedAddress((struct sockaddr *)&peerAddr); - - // Authenticate the response. - if (this->oldPassword.empty()) - response->Authenticate(this->password); - else - response->Authenticate(this->oldPassword); - - // Send back. - response->Serialize(StunSerializeBuffer); - this->listener->OnIceServerSendStunPacket(this, response, tuple); - - delete response; - - // Handle the tuple. - HandleTuple(tuple, packet->HasUseCandidate()); - - break; - } - - case RTC::StunPacket::Class::INDICATION: - { - MS_DEBUG_TAG(ice, "STUN Binding Indication processed"); - - break; - } - - case RTC::StunPacket::Class::SUCCESS_RESPONSE: - { - MS_DEBUG_TAG(ice, "STUN Binding Success Response processed"); - - break; - } - - case RTC::StunPacket::Class::ERROR_RESPONSE: - { - MS_DEBUG_TAG(ice, "STUN Binding Error Response processed"); - - break; - } - } - } - - bool IceServer::IsValidTuple(const RTC::TransportTuple* tuple) const - { - MS_TRACE(); - - return HasTuple(tuple) != nullptr; - } - - void IceServer::RemoveTuple(RTC::TransportTuple* tuple) - { - MS_TRACE(); - - RTC::TransportTuple* removedTuple{ nullptr }; - - // Find the removed tuple. - auto it = this->tuples.begin(); - - for (; it != this->tuples.end(); ++it) - { - RTC::TransportTuple* storedTuple = *it; - - if (storedTuple == tuple) - { - removedTuple = storedTuple; - - break; - } - } - - // If not found, ignore. - if (!removedTuple) - return; - - // Remove from the list of tuples. - this->tuples.erase(it); - - // If this is not the selected tuple, stop here. - if (removedTuple != this->selectedTuple) - return; - - // Otherwise this was the selected tuple. - this->selectedTuple = nullptr; - - // Mark the first tuple as selected tuple (if any). - if (!this->tuples.empty()) - { - SetSelectedTuple(this->tuples.front()); - } - // Or just emit 'disconnected'. - else - { - // Update state. - this->state = IceState::DISCONNECTED; - // Notify the listener. - this->listener->OnIceServerDisconnected(this); - } - } - - void IceServer::ForceSelectedTuple(const RTC::TransportTuple* tuple) - { - MS_TRACE(); - - MS_ASSERT( - this->selectedTuple, "cannot force the selected tuple if there was not a selected tuple"); - - auto* storedTuple = HasTuple(tuple); - - MS_ASSERT( - storedTuple, - "cannot force the selected tuple if the given tuple was not already a valid tuple"); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - } - - void IceServer::HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate) - { - MS_TRACE(); - - switch (this->state) - { - case IceState::NEW: - { - // There should be no tuples. - MS_ASSERT( - this->tuples.empty(), "state is 'new' but there are %zu tuples", this->tuples.size()); - - // There shouldn't be a selected tuple. - MS_ASSERT(!this->selectedTuple, "state is 'new' but there is selected tuple"); - - if (!hasUseCandidate) - { - MS_DEBUG_TAG(ice, "transition from state 'new' to 'connected'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::CONNECTED; - // Notify the listener. - this->listener->OnIceServerConnected(this); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'new' to 'completed'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::DISCONNECTED: - { - // There should be no tuples. - MS_ASSERT( - this->tuples.empty(), - "state is 'disconnected' but there are %zu tuples", - this->tuples.size()); - - // There shouldn't be a selected tuple. - MS_ASSERT(!this->selectedTuple, "state is 'disconnected' but there is selected tuple"); - - if (!hasUseCandidate) - { - MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'connected'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::CONNECTED; - // Notify the listener. - this->listener->OnIceServerConnected(this); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'disconnected' to 'completed'"); - - // Store the tuple. - auto* storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::CONNECTED: - { - // There should be some tuples. - MS_ASSERT(!this->tuples.empty(), "state is 'connected' but there are no tuples"); - - // There should be a selected tuple. - MS_ASSERT(this->selectedTuple, "state is 'connected' but there is not selected tuple"); - - if (!hasUseCandidate) - { - // If a new tuple store it. - if (!HasTuple(tuple)) - AddTuple(tuple); - } - else - { - MS_DEBUG_TAG(ice, "transition from state 'connected' to 'completed'"); - - auto* storedTuple = HasTuple(tuple); - - // If a new tuple store it. - if (!storedTuple) - storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - // Update state. - this->state = IceState::COMPLETED; - // Notify the listener. - this->listener->OnIceServerCompleted(this); - } - - break; - } - - case IceState::COMPLETED: - { - // There should be some tuples. - MS_ASSERT(!this->tuples.empty(), "state is 'completed' but there are no tuples"); - - // There should be a selected tuple. - MS_ASSERT(this->selectedTuple, "state is 'completed' but there is not selected tuple"); - - if (!hasUseCandidate) - { - // If a new tuple store it. - if (!HasTuple(tuple)) - AddTuple(tuple); - } - else - { - auto* storedTuple = HasTuple(tuple); - - // If a new tuple store it. - if (!storedTuple) - storedTuple = AddTuple(tuple); - - // Mark it as selected tuple. - SetSelectedTuple(storedTuple); - } - - break; - } - } - } - - inline RTC::TransportTuple* IceServer::AddTuple(RTC::TransportTuple* tuple) - { - MS_TRACE(); - - // Add the new tuple at the beginning of the list. - this->tuples.push_front(tuple); - - // Return the address of the inserted tuple. - return tuple; - } - - inline RTC::TransportTuple* IceServer::HasTuple(const RTC::TransportTuple* tuple) const - { - MS_TRACE(); - - // If there is no selected tuple yet then we know that the tuples list - // is empty. - if (!this->selectedTuple) - return nullptr; - - // Check the current selected tuple. - if (selectedTuple == tuple) - return this->selectedTuple; - - // Otherwise check other stored tuples. - for (const auto& it : this->tuples) - { - auto& storedTuple = it; - if (storedTuple == tuple) - return storedTuple; - } - - return nullptr; - } - - inline void IceServer::SetSelectedTuple(RTC::TransportTuple* storedTuple) - { - MS_TRACE(); - - // If already the selected tuple do nothing. - if (storedTuple == this->selectedTuple) - return; - - this->selectedTuple = storedTuple; - this->lastSelectedTuple = std::static_pointer_cast(storedTuple->shared_from_this()); - - // Notify the listener. - this->listener->OnIceServerSelectedTuple(this, this->selectedTuple); - } -} // namespace RTC diff --git a/webrtc/IceServer.hpp b/webrtc/IceServer.hpp deleted file mode 100644 index 316d32af..00000000 --- a/webrtc/IceServer.hpp +++ /dev/null @@ -1,138 +0,0 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef MS_RTC_ICE_SERVER_HPP -#define MS_RTC_ICE_SERVER_HPP - -#include "StunPacket.hpp" -#include "Network/Session.h" -#include "logger.h" -#include "Utils.hpp" -#include -#include -#include -#include - -namespace RTC -{ - using TransportTuple = toolkit::Session; - class IceServer - { - public: - enum class IceState - { - NEW = 1, - CONNECTED, - COMPLETED, - DISCONNECTED - }; - - public: - class Listener - { - public: - virtual ~Listener() = default; - - public: - /** - * These callbacks are guaranteed to be called before ProcessStunPacket() - * returns, so the given pointers are still usable. - */ - virtual void OnIceServerSendStunPacket( - const RTC::IceServer* iceServer, const RTC::StunPacket* packet, RTC::TransportTuple* tuple) = 0; - virtual void OnIceServerSelectedTuple( - const RTC::IceServer* iceServer, RTC::TransportTuple* tuple) = 0; - virtual void OnIceServerConnected(const RTC::IceServer* iceServer) = 0; - virtual void OnIceServerCompleted(const RTC::IceServer* iceServer) = 0; - virtual void OnIceServerDisconnected(const RTC::IceServer* iceServer) = 0; - }; - - public: - IceServer(Listener* listener, const std::string& usernameFragment, const std::string& password); - - public: - void ProcessStunPacket(RTC::StunPacket* packet, RTC::TransportTuple* tuple); - const std::string& GetUsernameFragment() const - { - return this->usernameFragment; - } - const std::string& GetPassword() const - { - return this->password; - } - IceState GetState() const - { - return this->state; - } - RTC::TransportTuple* GetSelectedTuple(bool try_last_tuple = false) const - { - return try_last_tuple ? this->lastSelectedTuple.lock().get() : this->selectedTuple; - } - void SetUsernameFragment(const std::string& usernameFragment) - { - this->oldUsernameFragment = this->usernameFragment; - this->usernameFragment = usernameFragment; - } - void SetPassword(const std::string& password) - { - this->oldPassword = this->password; - this->password = password; - } - bool IsValidTuple(const RTC::TransportTuple* tuple) const; - void RemoveTuple(RTC::TransportTuple* tuple); - // This should be just called in 'connected' or completed' state - // and the given tuple must be an already valid tuple. - void ForceSelectedTuple(const RTC::TransportTuple* tuple); - - const std::list& GetTuples() const { return tuples; } - - private: - void HandleTuple(RTC::TransportTuple* tuple, bool hasUseCandidate); - /** - * Store the given tuple and return its stored address. - */ - RTC::TransportTuple* AddTuple(RTC::TransportTuple* tuple); - /** - * If the given tuple exists return its stored address, nullptr otherwise. - */ - RTC::TransportTuple* HasTuple(const RTC::TransportTuple* tuple) const; - /** - * Set the given tuple as the selected tuple. - * NOTE: The given tuple MUST be already stored within the list. - */ - void SetSelectedTuple(RTC::TransportTuple* storedTuple); - - private: - // Passed by argument. - Listener* listener{ nullptr }; - // Others. - std::string usernameFragment; - std::string password; - std::string oldUsernameFragment; - std::string oldPassword; - IceState state{ IceState::NEW }; - std::list tuples; - RTC::TransportTuple *selectedTuple { nullptr }; - std::weak_ptr lastSelectedTuple; - //最大不超过mtu - static constexpr size_t StunSerializeBufferSize{ 1600 }; - uint8_t StunSerializeBuffer[StunSerializeBufferSize]; - }; -} // namespace RTC - -#endif diff --git a/webrtc/IceSession.cpp b/webrtc/IceSession.cpp new file mode 100644 index 00000000..207d555b --- /dev/null +++ b/webrtc/IceSession.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "IceSession.hpp" +#include "Util/util.h" +#include "Common/config.h" +#include "WebRtcTransport.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +static IceSession::Ptr queryIceTransport(uint8_t *data, size_t size) { + auto packet = RTC::StunPacket::parse((const uint8_t *)data, size); + if (!packet) { + WarnL << "parse stun error"; + return nullptr; + } + + auto username = packet->getUsername(); + return IceSessionManager::Instance().getItem(username); +} + +//////////// IceSession ////////////////////////// +IceSession::IceSession(const Socket::Ptr &sock) : Session(sock) { + TraceL << getIdentifier(); + _over_tcp = sock->sockType() == SockNum::Sock_TCP; + GET_CONFIG(string, iceUfrag, Rtc::kIceUfrag); + GET_CONFIG(string, icePwd, Rtc::kIcePwd); + _ice_transport = std::make_shared(this, iceUfrag, icePwd, getPoller()); + _ice_transport->initialize(); +} + +IceSession::~IceSession() { + TraceL << getIdentifier(); +} + +EventPoller::Ptr IceSession::queryPoller(const Buffer::Ptr &buffer) { + auto transport = queryIceTransport((uint8_t *)buffer->data(), buffer->size()); + return transport ? transport->getPoller() : nullptr; +} + +void IceSession::onRecv(const Buffer::Ptr &buffer) { + // TraceL; + if (_over_tcp) { + input(buffer->data(), buffer->size()); + } + else{ + onRecv_l(buffer->data(), buffer->size()); + } +} + +void IceSession::onRecv_l(const char* buffer, size_t size) { + if (!_session_pair) { + _session_pair = std::make_shared(shared_from_this()); + } + _ice_transport->processSocketData((const uint8_t *)buffer, size, _session_pair); +} + +void IceSession::onError(const SockException &err) { + InfoL; + // 消除循环引用 + _session_pair = nullptr; +} + +void IceSession::onManager() { +} + +ssize_t IceSession::onRecvHeader(const char *data, size_t len) { + onRecv_l(data + 2, len - 2); + return 0; +} + +const char *IceSession::onSearchPacketTail(const char *data, size_t len) { + if (len < 2) { + // Not enough data + return nullptr; + } + uint16_t length = (((uint8_t *)data)[0] << 8) | ((uint8_t *)data)[1]; + if (len < (size_t)(length + 2)) { + // Not enough data + return nullptr; + } + // Return the end of the RTP packet + return data + 2 + length; +} + +void IceSession::onIceTransportRecvData(const toolkit::Buffer::Ptr& buffer, const IceTransport::Pair::Ptr& pair) { + _ice_transport->processSocketData((const uint8_t *)buffer->data(), buffer->size(), pair); +} + +void IceSession::onIceTransportGatheringCandidate(const IceTransport::Pair::Ptr& pair, const CandidateInfo& candidate) { + DebugL << candidate.dumpString(); +} + +void IceSession::onIceTransportDisconnected() { + InfoL << getIdentifier(); +} + +void IceSession::onIceTransportCompleted() { + InfoL << getIdentifier(); +} + +//////////// IceSessionManager ////////////////////////// + +IceSessionManager &IceSessionManager::Instance() { + static IceSessionManager s_instance; + return s_instance; +} + +void IceSessionManager::addItem(const std::string& key, const IceSession::Ptr &ptr) { + std::lock_guard lck(_mtx); + _map[key] = ptr; +} + +IceSession::Ptr IceSessionManager::getItem(const std::string& key) { + assert(!key.empty()); + std::lock_guard lck(_mtx); + auto it = _map.find(key); + if (it == _map.end()) { + return nullptr; + } + return it->second.lock(); +} + +void IceSessionManager::removeItem(const std::string& key) { + std::lock_guard lck(_mtx); + _map.erase(key); +} + +}// namespace mediakit diff --git a/webrtc/IceSession.hpp b/webrtc/IceSession.hpp new file mode 100644 index 00000000..571b8af0 --- /dev/null +++ b/webrtc/IceSession.hpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + + +#ifndef ZLMEDIAKIT_WEBRTC_ICE_SESSION_H +#define ZLMEDIAKIT_WEBRTC_ICE_SESSION_H + +#include "Network/Session.h" +#include "IceTransport.hpp" +#include "Http/HttpRequestSplitter.h" + +namespace mediakit { + +class IceSession : public toolkit::Session, public RTC::IceTransport::Listener, public HttpRequestSplitter { +public: + using Ptr = std::shared_ptr; + using WeakPtr = std::weak_ptr; + IceSession(const toolkit::Socket::Ptr &sock); + ~IceSession() override; + + static toolkit::EventPoller::Ptr queryPoller(const toolkit::Buffer::Ptr &buffer); + + //// Session override//// + // void attachServer(const Server &server) override; + void onRecv(const toolkit::Buffer::Ptr &) override; + void onError(const toolkit::SockException &err) override; + void onManager() override; + + // ice related callbacks /// + void onIceTransportRecvData(const toolkit::Buffer::Ptr& buffer, const RTC::IceTransport::Pair::Ptr& pair) override; + void onIceTransportGatheringCandidate(const RTC::IceTransport::Pair::Ptr& pair, const RTC::CandidateInfo& candidate) override; + void onIceTransportDisconnected() override; + void onIceTransportCompleted() override; + + //// HttpRequestSplitter override //// + ssize_t onRecvHeader(const char *data, size_t len) override; + const char *onSearchPacketTail(const char *data, size_t len) override; + + void onRecv_l(const char *data, size_t len); +protected: + bool _over_tcp = false; + + RTC::IceTransport::Pair::Ptr _session_pair = nullptr; + RTC::IceServer::Ptr _ice_transport; +}; + +class IceSessionManager { +public: + static IceSessionManager &Instance(); + IceSession::Ptr getItem(const std::string& key); + void addItem(const std::string& key, const IceSession::Ptr &ptr); + void removeItem(const std::string& key); + +private: + IceSessionManager() = default; + +private: + std::mutex _mtx; + std::unordered_map> _map; +}; +}// namespace mediakit + +#endif //ZLMEDIAKIT_WEBRTC_ICE_SESSION_H diff --git a/webrtc/IceTransport.cpp b/webrtc/IceTransport.cpp new file mode 100644 index 00000000..84eecdca --- /dev/null +++ b/webrtc/IceTransport.cpp @@ -0,0 +1,2065 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. +*/ + +#include +#include +#include +#include "json/json.h" +#include "Util/onceToken.h" +#include "Network/UdpClient.h" +#include "Common/Parser.h" +#include "Common/config.h" +#include "IceTransport.hpp" +#include "WebRtcTransport.h" + +using namespace std; +using namespace toolkit; +using namespace mediakit; + +namespace RTC { + +#define RTC_FIELD "rtc." +const string kPortRange = RTC_FIELD "port_range"; +const string kMaxStunRetry = RTC_FIELD "max_stun_retry"; +static onceToken token([]() { + mINI::Instance()[kPortRange] = "49152-65535"; + mINI::Instance()[kMaxStunRetry] = 7; +}); + +static uint32_t calIceCandidatePriority(CandidateInfo::AddressType type, uint32_t component_id = 1) { + uint32_t type_preference; + switch (type) { + case CandidateInfo::AddressType::HOST: type_preference = 126; break; + case CandidateInfo::AddressType::PRFLX: type_preference = 110; break; + case CandidateInfo::AddressType::SRFLX: type_preference = 100; break; + case CandidateInfo::AddressType::RELAY: type_preference = 0; break; + default: throw std::invalid_argument(StrPrinter << "not support type :" << (uint32_t)type); + } + + uint32_t local_preference = 100; + return (type_preference << 24) + (local_preference << 8) + (256 - component_id); +} + +uint64_t calCandidatePairPriority(uint32_t G, uint32_t D) { + uint32_t min_p = (G < D) ? G : D; + uint32_t max_p = (G > D) ? G : D; + return ((uint64_t)min_p << 32) | (2 * (uint64_t)max_p) | (G > D ? 1 : 0); +} + +std::string addrToStr(const sockaddr_storage& addr) { + return StrPrinter << SockUtil::inet_ntoa((const struct sockaddr*)&addr) + << ":" << SockUtil::inet_port((const struct sockaddr*)&addr); +} + +// 检查ICE传输策略是否允许该候选者对 +static bool checkIceTransportPolicy(const IceAgent::CandidatePair& pair_info, const IceTransport::Pair::Ptr& pair) { + GET_CONFIG(int, ice_transport_policy, Rtc::kIceTransportPolicy); + + // 优先使用新的统一配置参数 + switch (static_cast(ice_transport_policy)) { + case IceTransportPolicy::kRelayOnly: + // 仅支持Relay转发:要求本地或远程是中继类型 + if (pair_info._local_candidate._type != CandidateInfo::AddressType::RELAY && + pair_info._remote_candidate._type != CandidateInfo::AddressType::RELAY) { + DebugL << "relay only policy, skip pair: " << pair_info.dumpString(); + return false; + } + break; + + case IceTransportPolicy::kP2POnly: + // 仅支持P2P直连:要求本地和远程都不是中继类型 + if (pair_info._local_candidate._type == CandidateInfo::AddressType::RELAY || + pair_info._remote_candidate._type == CandidateInfo::AddressType::RELAY) { + DebugL << "p2p only policy, skip pair: " << pair_info.dumpString(); + return false; + } + break; + + case IceTransportPolicy::kAll: + default: break; + } + + return true; +} + +//////////// IceServerInfo ////////////////////////// +void IceServerInfo::parse(const std::string &url_in) { + DebugL << url_in; + + _full_url = url_in; + auto url = url_in; + + auto schema_pos = url.find(":"); + if (schema_pos == string::npos) { + throw std::runtime_error(StrPrinter << "fail to parse schema in url: " << url_in); + } + + auto schema = url.substr(0, schema_pos); + if (strcasecmp(schema.data(), "turns") == 0) { + _schema = SchemaType::TURN; + _secure = CandidateTuple::SecureType::SECURE; + } else if (strcasecmp(schema.data(), "turn") == 0) { + _schema = SchemaType::TURN; + _secure = CandidateTuple::SecureType::NOT_SECURE; + } else if (strcasecmp(schema.data(), "stuns") == 0) { + _schema = SchemaType::STUN; + _secure = CandidateTuple::SecureType::SECURE; + } else if (strcasecmp(schema.data(), "stun") == 0) { + _schema = SchemaType::STUN; + _secure = CandidateTuple::SecureType::NOT_SECURE; + } else { + throw std::runtime_error(StrPrinter << "not support schema: " << schema); + } + + // 解析了用户名密码之后再解析?参数,防止密码中的?被判为参数分隔符 + auto pos = url.find("?"); + if (pos != string::npos) { + _param_strs = url.substr(pos + 1); + url.erase(pos); + } + _addr._port = (_secure == CandidateTuple::SecureType::NOT_SECURE) ? 3478 : 5349; + auto host = url.substr(schema_pos + 1, pos); + mediakit::splitUrl(host, _addr._host, _addr._port); + + auto params = mediakit::Parser::parseArgs(_param_strs); + if (params.find("transport") != params.end()) { + auto transport = params["transport"]; + if (strcasecmp(transport.data(), "udp") == 0) { + _transport = CandidateTuple::TransportType::UDP; + } else if (strcasecmp(transport.data(), "tcp") == 0) { + _transport = CandidateTuple::TransportType::TCP; + } else { + throw std::runtime_error(StrPrinter <<"not support transport: " << transport); + } + } else { + _transport = CandidateTuple::TransportType::UDP; + } +} + +//////////// IceTransport ////////////////////////// + +IceTransport::IceTransport(Listener* listener, std::string ufrag, std::string password, EventPoller::Ptr poller) +: _poller(std::move(poller)), _listener(listener), _ufrag(std::move(ufrag)), _password(std::move(password)) { + TraceL; + _identifier = makeRandStr(32); + _request_handlers.emplace(std::make_pair(StunPacket::Class::REQUEST, StunPacket::Method::BINDING), + std::bind(&IceTransport::handleBindingRequest, this, std::placeholders::_1, std::placeholders::_2)); +} + +void IceTransport::initialize() { + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + _retry_timer = std::make_shared(0.1f, [weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return false; + } + strong_self->checkRequestTimeouts(); + return true; + }, getPoller()); +} + +void IceTransport::sendSocketData(const Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush) { + return sendSocketData_l(buf, pair, flush); +} + +void IceTransport::sendSocketData_l(const Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush) { + // DebugL; + if (pair == nullptr) { + throw std::invalid_argument("pair should not be nullptr"); + } + + // 一次性发送一帧的rtp数据,提高网络io性能 [AUTO-TRANSLATED:fbab421e] + // Send one frame of rtp data at a time to improve network io performance + if (pair->_socket->getSock()->sockType() == SockNum::Sock_TCP) { + // 增加tcp两字节头 + uint16_t len = htons(buf->size()); + pair->_socket->SockSender::send((char *)&len, 2); + } + +#if 0 + TraceL << pair->dumpString(1) << " send " << buf->size(); + TraceL << "data: " << hexdump(buf->data(), buf->size()); +#endif + + sockaddr_storage peer_addr; + pair->get_peer_addr(peer_addr); + auto addr_len = SockUtil::get_sock_len((const struct sockaddr*)&peer_addr); + pair->_socket->sendto(buf, (struct sockaddr*)&peer_addr, addr_len); + if (flush) { + pair->_socket->flushAll(); + } +} + +bool IceTransport::processSocketData(const uint8_t* data, size_t len, const Pair::Ptr& pair) { +#if 0 + TraceL << pair->dumpString(0) << " data len: " << len; + sockaddr_storage relay_peer_addr; + if (pair->get_relayed_addr(relay_peer_addr)) { + TraceL << "data relay from peer " << addrToStr(relay_peer_addr); + } +#endif + + auto packet = StunPacket::parse(data, len); + if (!packet) { + return processChannelData(data, len, pair); + } + processStunPacket(packet, pair); + return true; +} + +void IceTransport::processStunPacket(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { +#if 0 + TraceL << "recv packet : " << packet->dumpString(true); +#endif + if ((packet->getClass() == StunPacket::Class::REQUEST) || (packet->getClass() == StunPacket::Class::INDICATION)) { + processRequest(packet, pair); + } else { + processResponse(packet, pair); + } +} + +StunPacket::Authentication IceTransport::checkRequestAuthentication(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + if (packet->getClass() == StunPacket::Class::INDICATION) { + return StunPacket::Authentication::OK; + } +#if 0 + DebugL << "_ufrag: " << _ufrag << ", _password: " << _password; +#endif + // Check authentication. + auto ret = packet->checkAuthentication(_ufrag, _password); + if (ret != StunPacket::Authentication::OK) { + sendUnauthorizedResponse(packet, pair); + } + return ret; +} + +StunPacket::Authentication IceTransport::checkResponseAuthentication(const StunPacket::Ptr& request, const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + if (!packet->hasAttribute(StunAttribute::Type::FINGERPRINT)) { + sendUnauthorizedResponse(packet, pair); + return StunPacket::Authentication::UNAUTHORIZED; + } +#if 0 + DebugL << "peer_ufrag: " << request->getPeerUfrag() << ", peer_password: " << request->getPeerPassword(); +#endif + auto ret = packet->checkAuthentication(request->getPeerUfrag(), request->getPeerPassword()); + if (ret != StunPacket::Authentication::OK) { + sendUnauthorizedResponse(packet, pair); + } + return ret; +} + +void IceTransport::processResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + + auto it = _response_handlers.find(packet->getTransactionId().data()); + if (it == _response_handlers.end()) { + WarnL << "not support stun transaction_id ignore: " << packet->dumpString(true); + return; + } + + auto request = std::move(it->second._request); + auto handle = std::move(it->second._handler); + // 收到响应后立即清理请求信息 + _response_handlers.erase(it); + + if (packet->getClass() == StunPacket::Class::ERROR_RESPONSE) { + if (StunAttrErrorCode::Code::Unauthorized == packet->getErrorCode()) { + return processUnauthorizedResponse(packet, request, pair, std::move(handle)); + } + return; + } + + if (StunPacket::Authentication::OK != checkResponseAuthentication(request, packet, pair)) { + WarnL << "checkRequestAuthentication fail: " << packet->dumpString(); + return; + } + + handle(packet, pair); +} + +#pragma pack(push, 1) +struct ChannelDataHeader { + uint16_t channel_number; + uint16_t data_length; +}; +#pragma pack(pop) + +bool IceTransport::processChannelData(const uint8_t* data, size_t len, const Pair::Ptr& pair) { + // DebugL; + + // 检查数据长度是否足够 + if (len < 4) { + WarnL << "Received data too short to be a valid STUN or ChannelData message"; + return false; + } + ChannelDataHeader header = *(reinterpret_cast(data)); + header.channel_number = ntohs(header.channel_number); + header.data_length = ntohs(header.data_length); + + // 检查是否是ChannelData消息 + // ChannelData消息的前两个字节是通道号,范围是0x4000-0x7FFF + if (header.channel_number < 0x4000 || header.channel_number > 0x7FFF) { + // WarnL << "Invalid channel number: " << header.channel_number; + return false; + } + + // 这是一个ChannelData消息; 检查数据长度是否合法 + if (len < 4 + header.data_length) { + WarnL << "ChannelData message truncated, len: " << len << ", data_length: " << header.data_length; + return false; + } + + handleChannelData(header.channel_number, (const char *)(data + 4), header.data_length, pair); + return true; +} + +void IceTransport::processUnauthorizedResponse(const StunPacket::Ptr& response, const StunPacket::Ptr& request, const Pair::Ptr& pair, MsgHandler handler) { + // TraceL; + auto attr_nonce = response->getAttribute(); + auto attr_realm = response->getAttribute(); + if (!attr_nonce || !attr_realm) { + return; + } + + request->refreshTransactionId(); + request->addAttribute(std::move(attr_nonce)); + request->addAttribute(std::move(attr_realm)); + request->setNeedMessageIntegrity(true); + sendRequest(request, pair, std::move(handler)); +} + +void IceTransport::processRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + if (StunPacket::Authentication::OK != checkRequestAuthentication(packet, pair)) { + WarnL << "checkRequestAuthentication fail: " << packet->dumpString(); + return; + } + + auto it = _request_handlers.find(std::make_pair(packet->getClass(), packet->getMethod())); + if (it == _request_handlers.end()) { + WarnL << "ignore unsupport stun "<< packet->dumpString(); + return; + } + + return (it->second)(packet, pair); +} + +void IceTransport::handleBindingRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + + auto response = packet->createSuccessResponse(); + response->setUfrag(_ufrag); + response->setPassword(_password); + + sockaddr_storage peer_addr; + if (!pair->get_relayed_addr(peer_addr)) { + pair->get_peer_addr(peer_addr); + } + + // Add XOR-MAPPED-ADDRESS. + auto attr_xor_mapped_address = std::make_shared(response->getTransactionId()); + attr_xor_mapped_address->setAddr(peer_addr); + response->addAttribute(std::move(attr_xor_mapped_address)); + + sendPacket(response, pair); +} + +void IceTransport::sendChannelData(uint16_t channel_number, const Buffer::Ptr& buffer, const Pair::Ptr& pair) { + // TraceL; + + // ChannelData不是STUN消息,需要单独实现 + // ChannelData格式:2字节Channel Number + 2字节数据长度 + 数据内容 + auto data_len = buffer->size(); + size_t total_len = 4 + data_len; + // 分配缓冲区:头部4字节 + 数据长度 + auto channel_data = toolkit::BufferRaw::create(total_len); + auto header = reinterpret_cast(channel_data->data()); + // 设置Channel Number (前两字节,网络字节序) + header->channel_number = htons(channel_number); + // 设置数据长度 (中间两字节,网络字节序) + header->data_length = htons(data_len); + // 拷贝数据 + memcpy(channel_data->data() + 4, buffer->data(), data_len); + channel_data->setSize(total_len); + +#if 0 + TraceL << pair->dumpString(1) << " send channel " << channel_number << " data " << data_len; + TraceL << "data: " << hexdump(buffer->data(), buffer->size()); +#endif + + sendSocketData(channel_data, pair); +} + +void IceTransport::sendUnauthorizedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + auto response = packet->createErrorResponse(StunAttrErrorCode::Code::Unauthorized); + sendPacket(response, pair); +} + +void IceTransport::sendErrorResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, StunAttrErrorCode::Code errorCode) { + // TraceL; + auto response = packet->createErrorResponse(errorCode); + sendPacket(response, pair); +} + +void IceTransport::sendRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair, MsgHandler handler) { + // TraceL; + _response_handlers.emplace(packet->getTransactionId().data(), RequestInfo(packet, std::move(handler), pair)); + sendPacket(packet, pair); +} + +void IceTransport::sendPacket(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { +#if 0 + TraceL << pair->dumpString(1) << " send packet: " << packet->dumpString(true); +#endif + packet->serialize(); + sendSocketData(std::static_pointer_cast(packet), pair); +} + +bool IceTransport::hasPermission(const sockaddr_storage& addr) { + auto it = _permissions.find(addr); + if (it == _permissions.end()) { + return false; + } + + // 权限有效期为5分钟 + uint64_t now = toolkit::getCurrentMillisecond(); + if (now - it->second > 5 * 60 * 1000) { + DebugL << "permissions over time, ip:" << addrToStr(addr); + _permissions.erase(it); + return false; + } + + return true; +} + +void IceTransport::addPermission(const sockaddr_storage& addr) { + _permissions[addr] = toolkit::getCurrentMillisecond(); +} + +bool IceTransport::hasChannelBind(uint16_t channel_number) { + return _channel_bindings.find(channel_number) != _channel_bindings.end(); +} + +bool IceTransport::hasChannelBind(const sockaddr_storage& addr, uint16_t& channel_number) { + for (const auto& binding : _channel_bindings) { + if (SockUtil::is_same_addr(reinterpret_cast(&binding.second), + reinterpret_cast(&addr))) { + channel_number = binding.first; + return true; + } + } + return false; +} + +void IceTransport::addChannelBind(uint16_t channel_number, const sockaddr_storage& addr) { + _channel_bindings[channel_number] = addr; + _channel_binding_times[channel_number] = toolkit::getCurrentMillisecond(); +} + +SocketHelper::Ptr IceTransport::createSocket(CandidateTuple::TransportType type, const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port) { + if (type != CandidateTuple::TransportType::UDP) { + throw std::invalid_argument("not support transport type: TCP"); + } + return createUdpSocket(peer_host, peer_port, local_ip, local_port); +} + +SocketHelper::Ptr IceTransport::createUdpSocket(const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port) { + auto socket = std::make_shared(getPoller()); + + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + auto ptr = socket.get(); + socket->setOnRecvFrom([weak_self, ptr](const Buffer::Ptr &buffer, struct sockaddr *addr, int addr_len) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + auto peer_host = SockUtil::inet_ntoa(addr); + auto peer_port = SockUtil::inet_port(addr); + auto pair = std::make_shared(ptr->shared_from_this(), std::move(peer_host), peer_port); + strong_self->_listener->onIceTransportRecvData(buffer, pair); + }); + + socket->setOnError([weak_self](const SockException &err) { + WarnL << err; + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + }); + + socket->setNetAdapter(local_ip); + socket->startConnect(peer_host, peer_port, local_port); + + return socket; +} + + +//////////// IceServer ////////////////////////// + +#ifndef UNUSED +#define UNUSED(x) ((void)(x)) +#endif +template +class PortManager : public std::enable_shared_from_this> { +public: + PortManager() = default; + + static PortManager &Instance() { + static auto instance = std::make_shared(); + return *instance; + } + + void addListenConfigReload(){ + weak_ptr weak_self = this->shared_from_this(); + static auto func = [weak_self](const string &str, int index) { + uint16_t port[] = { 49152, 65535 }; + auto strong_self = weak_self.lock(); + if (!strong_self) { + return port[index]; + } + sscanf(str.data(), "%" SCNu16 "-%" SCNu16, port, port + 1); + strong_self->setRange(port[0], port[1]); + return port[index]; + }; + + GET_CONFIG_FUNC(uint16_t, dummy_min_port, kPortRange, [](const string &str) { return func(str, 0); }); + GET_CONFIG_FUNC(uint16_t, dummy_max_port, kPortRange, [](const string &str) { return func(str, 1); }); + UNUSED(dummy_min_port); + UNUSED(dummy_max_port); + } + + std::shared_ptr getSinglePort() { + lock_guard lck(_pool_mtx); + if (_port_pool.empty()) { + return nullptr; + } + + auto pos = _port_pool.front(); + _port_pool.pop_front(); + InfoL << "got port from pool:" << pos; + + weak_ptr weak_self = this->shared_from_this(); + std::shared_ptr ret(new uint16_t(pos), [weak_self, pos](uint16_t *ptr) { + delete ptr; + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + + lock_guard lck(strong_self->_pool_mtx); + if (pos >= strong_self->_min_port && pos < strong_self->_max_port) { + InfoL << "return port:" << pos << " to pool"; + // 回收端口号 + strong_self->_port_pool.emplace_back(pos); + } else { + InfoL << "release port:" << pos << "[" << strong_self->_min_port << "-" << strong_self->_max_port << "]"; + // 端口范围修改过,该端口不在范围内了,不回收端口号 + } + }); + + return ret; + } + +private: + void setRange(uint16_t min_port, uint16_t max_port) { + assert(max_port >= min_port + 36 - 1); + lock_guard lck(_pool_mtx); + //端口范围未改变,不用处理 + if (min_port == _min_port && max_port == _max_port) { + return; + } + + InfoL << "setPortRange from [" << _min_port << "-" << _max_port << "] to [" << min_port << "-" << max_port << "]"; + + // 修改:直接使用端口值,不再除以2 + uint16_t start_pos = min_port; + uint16_t end_pos = max_port; + std::mt19937 rng(std::random_device {}()); + + //新指定的端口范围和原端口范围不交集,直接清除并重新增加 + if (max_port <= _min_port || min_port >= _max_port) { + _port_pool.clear(); + } else { + + //存在交集,先把交集范围内还未被分配的端口保留 + deque port_pool; + for (; !_port_pool.empty(); _port_pool.pop_front()) { + auto pos = _port_pool.front(); + + if (pos >= start_pos && pos < end_pos) { + port_pool.emplace_back(pos); + } + } + _port_pool.swap(port_pool); + + if (_min_port <= min_port && _max_port < max_port) { + // <_min_port|--------|********************|_max_port> + // = max_port && _min_port > min_port) { + // + // <_min_port|********************|--------|_max_port| + + start_pos = min_port; + end_pos = _min_port; + } else if (min_port < _min_port && max_port > _max_port) { + // _port_pool; +}; + +//注册端口管理监听配置重载 +onceToken PortManager_token([](){ + PortManager<0>::Instance().addListenConfigReload(); + PortManager<1>::Instance().addListenConfigReload(); +}); + +std::unordered_map _relayed_session; + +IceServer::IceServer(Listener* listener, std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller) + : IceTransport(listener, std::move(ufrag), std::move(password), std::move(poller)) { + DebugL; + + GET_CONFIG(bool, enable_turn, Rtc::kEnableTurn); + if (enable_turn) { + _request_handlers.emplace(std::make_pair(StunPacket::Class::REQUEST, StunPacket::Method::ALLOCATE), + std::bind(&IceServer::handleAllocateRequest, this, placeholders::_1, placeholders::_2)); + _request_handlers.emplace(std::make_pair(StunPacket::Class::REQUEST, StunPacket::Method::REFRESH), + std::bind(&IceServer::handleRefreshRequest, this, placeholders::_1, placeholders::_2)); + _request_handlers.emplace(std::make_pair(StunPacket::Class::REQUEST, StunPacket::Method::CREATEPERMISSION), + std::bind(&IceServer::handleCreatePermissionRequest, this, placeholders::_1, placeholders::_2)); + _request_handlers.emplace(std::make_pair(StunPacket::Class::REQUEST, StunPacket::Method::CHANNELBIND), + std::bind(&IceServer::handleChannelBindRequest, this, placeholders::_1, placeholders::_2)); + _request_handlers.emplace(std::make_pair(StunPacket::Class::INDICATION, StunPacket::Method::SEND), + std::bind(&IceServer::handleSendIndication, this, placeholders::_1, placeholders::_2)); + } + +} + +bool IceServer::processSocketData(const uint8_t* data, size_t len, const Pair::Ptr& pair) { + if (!_session_pair) { + _session_pair = pair; + } + return IceTransport::processSocketData(data, len, pair); +} + +void IceServer::processRelayPacket(const Buffer::Ptr &buffer, const Pair::Ptr& pair) { + // TraceL << pair->dumpString(0); + + sockaddr_storage peer_addr; + pair->get_peer_addr(peer_addr); + + if (!hasPermission(peer_addr)) { + WarnL << "No permission exists for peer: " << pair->get_peer_ip() << ":" << pair->get_peer_port(); + return; + } + + auto forward_pair = std::make_shared(_session_pair->_socket, pair->_socket->get_peer_ip(), pair->_socket->get_peer_port()); + uint16_t channel_number; + if (hasChannelBind(peer_addr, channel_number)) { + sendChannelData(channel_number, buffer, forward_pair); + } else { + sendDataIndication(peer_addr, buffer, forward_pair); + } +} + +void IceServer::handleAllocateRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + auto response = packet->createSuccessResponse(); + response->setUfrag(_ufrag); + response->setPassword(_password); + + // Add XOR-MAPPED-ADDRESS. + sockaddr_storage peer_addr; + pair->get_peer_addr(peer_addr); + + auto attr_xor_mapped_address = std::make_shared(response->getTransactionId()); + attr_xor_mapped_address->setAddr(peer_addr); + response->addAttribute(std::move(attr_xor_mapped_address)); + + // Add XOR-RELAYED-ADDRESS. + auto socket = allocateRelayed(pair); + sockaddr_storage relayed_addr = SockUtil::make_sockaddr(socket->get_local_ip().data(), socket->get_local_port()); + auto attr_xor_relayed_address = std::make_shared(response->getTransactionId()); + attr_xor_relayed_address->setAddr(relayed_addr); + response->addAttribute(std::move(attr_xor_relayed_address)); + + auto attr_lifetime = std::make_shared(); + attr_lifetime->setLifetime(600); + response->addAttribute(std::move(attr_lifetime)); + + sendPacket(response, pair); +} + +void IceServer::handleRefreshRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL +} + +void IceServer::handleCreatePermissionRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL + + // 检查XOR-PEER-ADDRESS属性是否存在 + auto peer_addr = packet->getAttribute(); + if (!peer_addr) { + WarnL << "CreatePermission request missing XOR-PEER-ADDRESS attribute"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + return; + } + + addPermission(peer_addr->getAddr()); + + auto response = packet->createSuccessResponse(); + response->setUfrag(_ufrag); + response->setPassword(_password); + sendPacket(response, pair); +} + +void IceServer::handleChannelBindRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL + + // 检查必要的属性 + auto channel_number = packet->getAttribute(); + auto peer_addr = packet->getAttribute(); + if (!channel_number || !peer_addr) { + WarnL << "ChannelBind request missing required attributes"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + return; + } + + // 验证通道号是否在有效范围内 (0x4000-0x7FFF) + uint16_t number = channel_number->getChannelNumber(); + if (number < 0x4000 || number > 0x7FFF) { + WarnL << "Invalid channel number: " << number; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + return; + } + + // 检查是否有对应peer地址的权限 + auto addr = peer_addr->getAddr(); + if (!hasPermission(addr)) { + WarnL << "No permission exists for peer address"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::Forbidden); + return; + } + + // 添加或更新通道绑定 + addChannelBind(number, addr); + + auto response = packet->createSuccessResponse(); + response->setUfrag(_ufrag); + response->setPassword(_password); + sendPacket(response, pair); +} + +void IceServer::handleSendIndication(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL + + // 检查必要的属性 + auto peer_addr = packet->getAttribute(); + auto data = packet->getAttribute(); + + if (!peer_addr || !data) { + WarnL << "Send indication missing required attributes"; + return; + } + + // 检查是否有对应peer地址的权限 + auto addr = peer_addr->getAddr(); + if (!hasPermission(addr)) { + WarnL << "No permission exists for peer address"; + return; + } + + auto buffer = data->getData(); + auto send_buffer = BufferRaw::create(buffer.size()); + send_buffer->assign(buffer.data(), buffer.size()); + return relayBackingData(send_buffer, pair, addr); +} + +void IceServer::handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) { + // TraceL << "Received ChannelData message, channel number: " << channel_number; + + // 查找该通道号对应的目标地址 + auto it = _channel_bindings.find(channel_number); + if (it == _channel_bindings.end()) { + WarnL << "No binding found for channel number: " << channel_number; + return; + } + + // 获取目标地址 + sockaddr_storage peer_addr = it->second; + + // 创建一个新的缓冲区用于转发 + auto buffer = BufferRaw::create(len); + buffer->assign(data, len); + + // 转发数据到目标地址 + relayBackingData(buffer, pair, peer_addr); +} + +void IceServer::sendUnauthorizedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL + + if (packet->getMethod() == StunPacket::Method::ALLOCATE) { + auto response = packet->createErrorResponse(StunAttrErrorCode::Code::Unauthorized); + auto attr_nonce = std::make_shared(); + auto nonce = makeRandStr(80); + _nonce_list.push_back(nonce); + attr_nonce->setNonce(nonce); + response->addAttribute(std::move(attr_nonce)); + + auto attr_realm = std::make_shared(); + attr_realm->setRealm("ZLM"); //TODO: use config.ini + response->addAttribute(std::move(attr_realm)); + sendPacket(response, pair); + return; + } + + IceTransport::sendUnauthorizedResponse(packet, pair); +} + +StunPacket::Authentication IceServer::checkRequestAuthentication(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL + + //ICE SERVER 不对BINDGING请求校验 + if (packet->getMethod() == StunPacket::Method::BINDING) { + return StunPacket::Authentication::OK; + } + + return IceTransport::checkRequestAuthentication(packet, pair); +} + +void IceServer::sendDataIndication(const sockaddr_storage& peer_addr, const Buffer::Ptr& buffer, const Pair::Ptr& pair) { + // TraceL + + auto packet = std::make_shared(); + + auto attr_peer_address = std::make_shared(packet->getTransactionId()); + attr_peer_address->setAddr(peer_addr); + packet->addAttribute(std::move(attr_peer_address)); + + auto attr_data = std::make_shared(); + attr_data->setData((const char *)buffer->data(), buffer->size()); + packet->addAttribute(std::move(attr_data)); + + sendPacket(packet, pair); +#if 0 + TraceL << pair->dumpString(1) << " Forward UDP data as DataIndication, size: " << buffer->size(); +#endif +} + +SocketHelper::Ptr IceServer::allocateRelayed(const Pair::Ptr& pair) { + // DebugL; + + // only support udp + auto port = PortManager<0>::Instance().getSinglePort(); + std::string local_ip; + + GET_CONFIG_FUNC(std::vector, interfaces, Rtc::kInterfaces, [](string str) { + std::vector ret; + if (str.length()) { + ret = split(str, ","); + } + translateIPFromEnv(ret); + return ret; + }); + + //如果指定了对外的网卡,使用第一个对外网卡的ip + if (!interfaces.empty()) { + auto machine_interfaces = SockUtil::getInterfaceList(); + for (auto& obj : machine_interfaces) { + std::string& interface_ip = obj["ip"]; + if (toolkit::start_with(obj["name"], "lo") || interface_ip == "0.0.0.0" || interface_ip == "::") { + DebugL << "skip interace: " << obj["name"] << " " << interface_ip; + continue; + } + if (obj["name"] == interfaces.front()) { + DebugL << "use interace: " << obj["name"] << " " << interface_ip; + local_ip = interface_ip; + break; + } + } + } + + if (interfaces.empty() || local_ip.empty()){ + GET_CONFIG_FUNC(std::vector, extern_ips, Rtc::kExternIP, [](string str) { + std::vector ret; + if (str.length()) { + ret = split(str, ","); + } + translateIPFromEnv(ret); + return ret; + }); + + if (!extern_ips.empty()) { + local_ip = extern_ips.front(); + } else { + local_ip = SockUtil::get_local_ip(); + } + } + + auto socket = createRelayedUdpSocket(pair->get_peer_ip(), pair->get_peer_port(), local_ip, *port); + auto relayed_pair = std::make_shared(socket); + auto peer_addr = SockUtil::make_sockaddr(pair->get_peer_ip().data(), pair->get_peer_port()); + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + _relayed_pairs.emplace(peer_addr, std::make_pair(port, relayed_pair)); + _relayed_session.emplace(peer_addr, weak_self); + + InfoL << "Alloc relayed pair: " << relayed_pair->get_local_ip() << ":" << relayed_pair->get_local_port() + << " for peer pair: " << pair->get_peer_ip() << ":" << pair->get_peer_port(); + return socket; +} + +void IceServer::relayForwordingData(const toolkit::Buffer::Ptr& buffer, const sockaddr_storage& peer_addr) { + TraceL; + getPoller()->async([=]() { + auto it = _relayed_pairs.find(peer_addr); + if (it == _relayed_pairs.end()) { +#if 0 + //不是当前对象的转发,交给其他对象转发 + auto forword_it = _relayed_session.find(peer_addr); + if (forword_it == _relayed_session.end()) { + WarnL << "not relayed addr for peer addr: " << addrToStr(peer_addr); + } + + auto forword_session = forword_it->second.lock(); + if (!forword_session) { + WarnL << "forword session for peer addr " << addrToStr(peer_addr) << " is release"; + return; + } + + if (getIdentifier() == forword_session->getIdentifier()) { + //找到的会话就是当前会话,忽略 + return; + } + forword_session->relayForwordingData(buffer, peer_addr); + return; +#else + WarnL << "not relayed addr for peer addr: " << addrToStr(peer_addr); +#endif + } + + sendSocketData(buffer, it->second.second); +#if 0 + TraceL << "Forwarded ChannelData to peer: " << addrToStr(peer_addr); +#endif + }); +} + +void IceServer::relayBackingData(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair, const sockaddr_storage& peer_addr) { + // TraceL; + + sockaddr_storage addr; + pair->get_peer_addr(addr); + + auto it = _relayed_pairs.find(addr); + if (it == _relayed_pairs.end()) { + WarnL << "not relayed addr for peer addr: " << addrToStr(addr); + return; + } + + auto forward_pair = std::make_shared(it->second.second->_socket, + SockUtil::inet_ntoa((const struct sockaddr *)&peer_addr), SockUtil::inet_port((const struct sockaddr *)&peer_addr)); + + sendSocketData(buffer, forward_pair); +#if 0 + DebugL << "relay backing " << forward_pair->dumpString(1); +#endif +} + +SocketHelper::Ptr IceServer::createRelayedUdpSocket(const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port) { + auto socket = std::make_shared(getPoller()); + + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + auto ptr = socket.get(); + socket->setOnRecvFrom([weak_self, ptr](const Buffer::Ptr &buffer, struct sockaddr *addr, int addr_len) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + auto peer_host = SockUtil::inet_ntoa(addr); + auto peer_port = SockUtil::inet_port(addr); + auto pair = std::make_shared(ptr->shared_from_this(), std::move(peer_host), peer_port); + strong_self->processRelayPacket(buffer, pair); + }); + + socket->setOnError([weak_self](const SockException &err) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + }); + + socket->setNetAdapter(local_ip); + socket->startConnect(peer_host, peer_port, local_port); + + return socket; +} + +//////////// IceAgent ////////////////////////// + +IceAgent::IceAgent(Listener* listener, Implementation implementation, Role role, std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller) +: IceTransport(listener, std::move(ufrag), std::move(password), std::move(poller)), _implementation(implementation) ,_role(role) { + DebugL; + _tiebreaker = makeRandNum(); + // 创建定时器,每分钟检查一次权限和通道绑定是否需要刷新 + _refresh_timer = std::make_shared(60.0f, [this]() { + refreshPermissions(); + refreshChannelBindings(); + return true; + }, getPoller()); +} + +void IceAgent::gatheringCandidate(const CandidateTuple::Ptr& candidate_tuple, bool gathering_rflx, bool gathering_relay) { + // TraceL; + + auto interfaces = SockUtil::getInterfaceList(); + for (auto obj : interfaces) { + std::string local_ip = obj["ip"]; + if (toolkit::start_with(obj["name"], "lo") || local_ip == "0.0.0.0" || local_ip == "::") { + DebugL << "skip interace: " << obj["name"] << " " << local_ip; + continue; + } + + try { + CandidateInfo candidate; + candidate._type = CandidateInfo::AddressType::HOST; + candidate._ufrag = getUfrag(); + candidate._pwd = getPassword(); + candidate._transport = candidate_tuple->_transport; + + auto socket = createSocket(candidate_tuple->_transport, candidate_tuple->_addr._host, candidate_tuple->_addr._port, local_ip); + _socket_candidate_manager.addHostSocket(socket); + candidate._addr._host = candidate._base_addr._host = local_ip; + candidate._addr._port = candidate._base_addr._port = socket->get_local_port(); + + TraceL << "gathering local candidate " << candidate.dumpString() << " from stun server " << candidate_tuple->_addr.dumpString(); + + auto pair = std::make_shared(std::move(socket)); + onGatheringCandidate(pair, candidate); + if (gathering_rflx) { + gatheringSrflxCandidate(pair); + } + + if (gathering_relay) { + //TODO: 代优化relay_socket 复用host socket当前SocketCandidateManager数据结构不支持 + auto relay_socket = createSocket(candidate_tuple->_transport, candidate_tuple->_addr._host, candidate_tuple->_addr._port, local_ip); + _socket_candidate_manager.addRelaySocket(relay_socket); + gatheringRelayCandidate(std::make_shared(std::move(relay_socket))); + } + } catch (std::exception &ex) { + WarnL << ex.what(); + } + } +} + +void IceAgent::connectivityCheck(CandidateInfo& candidate) { + TraceL << candidate.dumpString(); + setState(IceAgent::State::Running); + auto ret = _remote_candidates.emplace(candidate); + if (ret.second) { + bool udp = candidate._transport == CandidateTuple::TransportType::UDP; + for (auto& socket : _socket_candidate_manager._host_sockets) { + if (udp != (socket->getSock()->sockType() == SockNum::Sock_UDP)) { + continue; + } + auto pair = std::make_shared(socket, candidate._addr._host, candidate._addr._port); + addToChecklist(pair, candidate); + } + + if (_socket_candidate_manager._has_relayed_candidate) { + localRelayedConnectivityCheck(candidate); + } + } + +} + +void IceAgent::localRelayedConnectivityCheck(CandidateInfo& candidate) { + TraceL << candidate.dumpString(); + for (auto &socket: _socket_candidate_manager._relay_sockets) { + auto local_relay_pair = std::make_shared(socket, _ice_server->_addr._host, _ice_server->_addr._port); + auto peer_addr = SockUtil::make_sockaddr(candidate._addr._host.data(), candidate._addr._port); + sendCreatePermissionRequest(local_relay_pair, peer_addr); + + local_relay_pair->_relayed_addr = std::make_shared(peer_addr); + addToChecklist(local_relay_pair, candidate); + } +} + +void IceAgent::nominated(const Pair::Ptr& pair, CandidateTuple& candidate) { + // TraceL; + auto handler = std::bind(&IceAgent::handleNominatedResponse, this, placeholders::_1, placeholders::_2, candidate); + sendBindRequest(pair, candidate, true, std::move(handler)); +} + +void IceAgent::sendSendIndication(const sockaddr_storage& peer_addr, const Buffer::Ptr& buffer, const Pair::Ptr& pair) { + // TraceL; + auto packet = std::make_shared(); + + auto attr_peer_address = std::make_shared(packet->getTransactionId()); + attr_peer_address->setAddr(peer_addr); + packet->addAttribute(std::move(attr_peer_address)); + + auto attr_data = std::make_shared(); + attr_data->setData(buffer->data(), buffer->size()); + packet->addAttribute(std::move(attr_data)); + + sendPacket(packet, pair); +} + +void IceAgent::gatheringSrflxCandidate(const Pair::Ptr& pair) { + // TraceL; + auto handle = std::bind(&IceAgent::handleGatheringCandidateResponse, this, placeholders::_1, placeholders::_2); + sendBindRequest(pair, std::move(handle)); +} + +void IceAgent::gatheringRelayCandidate(const Pair::Ptr &pair) { + // TraceL; + sendAllocateRequest(pair); +} + +void IceAgent::connectivityCheck(const Pair::Ptr &pair, CandidateTuple& candidate) { + // TraceL; + auto handler = std::bind(&IceAgent::handleConnectivityCheckResponse, this, placeholders::_1, placeholders::_2, candidate); + sendBindRequest(pair, candidate, false, std::move(handler)); +} + +void IceAgent::tryTriggerredCheck(const Pair::Ptr& pair) { + // DebugL; + //FIXME 暂不实现,因为当前实现基本收到candidate就会发起check +} + +void IceAgent::sendBindRequest(const Pair::Ptr& pair, MsgHandler handler) { + // TraceL; + auto packet = std::make_shared(); + packet->setUfrag(_ufrag); + packet->setPassword(_password); + packet->setPeerUfrag(_ice_server->_ufrag); + packet->setPeerPassword(_ice_server->_pwd); + + packet->setNeedFingerprint(false); + packet->setNeedMessageIntegrity(false); + sendRequest(packet, pair, std::move(handler)); +} + +void IceAgent::sendBindRequest(const Pair::Ptr& pair, CandidateTuple& candidate, bool use_candidate, MsgHandler handler) { + // TraceL; + auto packet = std::make_shared(); + packet->setUfrag(_ufrag); + packet->setPassword(_password); + packet->setPeerUfrag(candidate._ufrag); + packet->setPeerPassword(candidate._pwd); + + auto attr_username = std::make_shared(); + attr_username->setUsername(candidate._ufrag + ":" + _ufrag); + packet->addAttribute(std::move(attr_username)); + + if (getRole() == Role::Controlling) { + auto attr_icecontrolling = std::make_shared(); + attr_icecontrolling->setTiebreaker(_tiebreaker); + packet->addAttribute(std::move(attr_icecontrolling)); + } else { + auto attr_icecontrolled = std::make_shared(); + attr_icecontrolled->setTiebreaker(_tiebreaker); + packet->addAttribute(std::move(attr_icecontrolled)); + } + + if (use_candidate) { + auto attr_use_candidate = std::make_shared(); + packet->addAttribute(std::move(attr_use_candidate)); + } + + if (candidate._priority != 0) { + auto attr_priority = std::make_shared(); + attr_priority->setPriority(candidate._priority); + packet->addAttribute(std::move(attr_priority)); + } + + sendRequest(packet, pair, std::move(handler)); +} + +void IceAgent::sendAllocateRequest(const Pair::Ptr& pair) { + // TraceL; + auto packet = std::make_shared(); + packet->setNeedMessageIntegrity(false); + packet->setUfrag(_ufrag); + packet->setPassword(_password); + packet->setPeerUfrag(_ice_server->_ufrag); + packet->setPeerPassword(_ice_server->_pwd); + + auto attr_username = std::make_shared(); + attr_username->setUsername(_ice_server->_ufrag); + packet->addAttribute(std::move(attr_username)); + + auto attr_requested_transport = std::make_shared(); + attr_requested_transport->setProtocol(StunAttrRequestedTransport::Protocol::UDP); + packet->addAttribute(std::move(attr_requested_transport)); + + auto handler = std::bind(&IceAgent::handleAllocateResponse, this, placeholders::_1, placeholders::_2); + sendRequest(packet, pair, std::move(handler)); +} + +void IceAgent::sendCreatePermissionRequest(const Pair::Ptr& pair, const sockaddr_storage& peer_addr) { + // TraceL; + + addPermission(peer_addr); + + auto packet = std::make_shared(); + packet->setUfrag(_ufrag); + packet->setPassword(_password); + packet->setPeerUfrag(_ice_server->_ufrag); + packet->setPeerPassword(_ice_server->_pwd); + + auto attr_username = std::make_shared(); + attr_username->setUsername(_ice_server->_ufrag); + packet->addAttribute(std::move(attr_username)); + auto attr_peer_address = std::make_shared(packet->getTransactionId()); + attr_peer_address->setAddr(peer_addr); + packet->addAttribute(std::move(attr_peer_address)); + + auto handler = std::bind(&IceAgent::handleCreatePermissionResponse, this, placeholders::_1, placeholders::_2, peer_addr); + sendRequest(packet, pair, std::move(handler)); +} + +void IceAgent::sendChannelBindRequest(const Pair::Ptr& pair, uint16_t channel_number, const sockaddr_storage& peer_addr) { + // TraceL; + auto packet = std::make_shared(); + packet->setUfrag(_ufrag); + packet->setPassword(_password); + packet->setPeerUfrag(_ice_server->_ufrag); + packet->setPeerPassword(_ice_server->_pwd); + + auto attr_username = std::make_shared(); + attr_username->setUsername(_ice_server->_ufrag); + packet->addAttribute(std::move(attr_username)); + + auto attr_channel_number = std::make_shared(); + attr_channel_number->setChannelNumber(channel_number); + packet->addAttribute(std::move(attr_channel_number)); + + auto attr_peer_address = std::make_shared(packet->getTransactionId()); + attr_peer_address->setAddr(peer_addr); + packet->addAttribute(std::move(attr_peer_address)); + + auto handler = std::bind(&IceAgent::handleChannelBindResponse, this, placeholders::_1, placeholders::_2, channel_number, peer_addr); + sendRequest(packet, pair, std::move(handler)); +} + +void IceAgent::processRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + static toolkit::onceToken token([this]() { + _request_handlers.emplace(std::make_pair(StunPacket::Class::INDICATION, StunPacket::Method::DATA), std::bind(&IceAgent::handleDataIndication, this, placeholders::_1, placeholders::_2)); + }); + return IceTransport::processRequest(packet, pair); +} + +void IceAgent::handleBindingRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + auto controlling = packet->getAttribute(); + auto controlled = packet->getAttribute(); + auto priority = packet->getAttribute(); + + //角色冲突 + if (controlling && getRole() == Role::Controlling) { + if (controlling->getTiebreaker() > _tiebreaker) { + setRole(Role::Controlled); + InfoL << "role conflict, election fail, change role to controlled"; + } else { + InfoL << "role conflict, election success in controlling, send error"; + auto response = packet->createErrorResponse(StunAttrErrorCode::Code::RoleConflict); + response->setUfrag(_ufrag); + response->setPassword(_password); + auto attr_icecontrolling = std::make_shared(); + attr_icecontrolling->setTiebreaker(_tiebreaker); + response->addAttribute(std::move(attr_icecontrolling)); + return sendPacket(response, pair); + } + } else if (controlled && getRole() == Role::Controlled) { + if (controlled->getTiebreaker() > _tiebreaker) { + setRole(Role::Controlling); + InfoL << "rule conflict, election fail, change role to controlling"; + } else { + InfoL << "rule conflict, election success in controlled, send error"; + auto response = packet->createErrorResponse(StunAttrErrorCode::Code::RoleConflict); + response->setUfrag(_ufrag); + response->setPassword(_password); + auto attr_icecontrolled = std::make_shared(); + attr_icecontrolled->setTiebreaker(_tiebreaker); + response->addAttribute(std::move(attr_icecontrolled)); + return sendPacket(response, pair); + } + } + + auto response = packet->createSuccessResponse(); + response->setUfrag(_ufrag); + response->setPassword(_password); + + sockaddr_storage peer_addr; + if (!pair->get_relayed_addr(peer_addr)) { + pair->get_peer_addr(peer_addr); + } + + // Add XOR-MAPPED-ADDRESS. + auto attr_xor_mapped_address = std::make_shared(response->getTransactionId()); + attr_xor_mapped_address->setAddr(peer_addr); + response->addAttribute(std::move(attr_xor_mapped_address)); + + if (packet->hasAttribute(StunAttribute::Type::USE_CANDIDATE)) { + if (getRole() == Role::Controlled) { + _nominated_response = response; + onCompleted(pair); + } + } else { + sendPacket(response, pair); + tryTriggerredCheck(pair); + } + +} + +void IceAgent::handleGatheringCandidateResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + + if (StunPacket::Class::SUCCESS_RESPONSE != packet->getClass()) { + WarnL << "fail, get response: " << packet->dumpString(); + return; + } + + auto srflx = packet->getAttribute(); + if (!srflx) { + WarnL << "Binding request missing XOR_MAPPED_ADDRESS attribute"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + } + + CandidateInfo candidate; + candidate._type = CandidateInfo::AddressType::SRFLX; + candidate._addr._host = srflx->getIp(); + candidate._addr._port = srflx->getPort(); + candidate._base_addr._host = pair->get_local_ip(); + candidate._base_addr._port = pair->get_local_port(); + candidate._ufrag = getUfrag(); + candidate._pwd = getPassword(); + onGatheringCandidate(pair, candidate); +} + +void IceAgent::handleConnectivityCheckResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, CandidateTuple& candidate) { + // TraceL; + + if (StunPacket::Class::SUCCESS_RESPONSE != packet->getClass()) { + WarnL << "fail, get response: " << packet->dumpString(); + if (packet->getErrorCode() == StunAttrErrorCode::Code::RoleConflict) { + InfoL << "process Role Conflict"; + + auto controlling = packet->getAttribute(); + auto controlled = packet->getAttribute(); + //角色冲突 + if (controlling && getRole() == Role::Controlling) { + if (controlling->getTiebreaker() > _tiebreaker) { + InfoL << "rule conflict, election fail, change role to controlled"; + setRole(Role::Controlled); + } else { + InfoL << "rule conflict, election success in controlling, skip"; + return; + } + } else if (controlled && getRole() == Role::Controlled) { + if (controlled->getTiebreaker() > _tiebreaker) { + InfoL << "rule conflict, election fail, change role to controlling"; + setRole(Role::Controlling); + } else { + InfoL << "rule conflict, election success in controlled, skip"; + return; + } + } + connectivityCheck(pair, candidate); + } + return; + } + + auto srflx = packet->getAttribute(); + if (!srflx) { + WarnL << "Binding request missing XOR_MAPPED_ADDRESS attribute"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + } + + if (!pair->_relayed_addr) { + //relay的消息不添加PRFLX candidaite + CandidateInfo preflx_candidate; + preflx_candidate._type = CandidateInfo::AddressType::PRFLX; + preflx_candidate._addr._host = srflx->getIp(); + preflx_candidate._addr._port = srflx->getPort(); + preflx_candidate._base_addr._host = pair->get_local_ip(); + preflx_candidate._base_addr._port = pair->get_local_port(); + preflx_candidate._ufrag = getUfrag(); + preflx_candidate._pwd = getPassword(); + onGatheringCandidate(pair, preflx_candidate); + } + + DebugL << "get candidate type preflx: " << srflx->getIp() << ":" << srflx->getPort(); + onConnected(pair); +} + +void IceAgent::handleNominatedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, CandidateTuple& candidate) { + // TraceL; + + if (StunPacket::Class::SUCCESS_RESPONSE != packet->getClass()) { + WarnL << "fail, get response: " << packet->dumpString(); + if (packet->getErrorCode() == StunAttrErrorCode::Code::RoleConflict) { + //角色冲突 + InfoL << "process Role Conflict"; + auto controlling = packet->getAttribute(); + auto controlled = packet->getAttribute(); + if (controlling && getRole() == Role::Controlling) { + if (controlling->getTiebreaker() > _tiebreaker) { + InfoL << "rule conflict, election fail, change role to controlled"; + setRole(Role::Controlled); + return; + } else { + InfoL << "rule conflict, election success in controlling, skip"; + return; + } + } else if (controlled && getRole() == Role::Controlled) { + if (controlled->getTiebreaker() > _tiebreaker) { + InfoL << "rule conflict, election fail, change role to controlling"; + setRole(Role::Controlling); + } else { + InfoL << "rule conflict, election success in controlled, skip"; + return; + } + } + nominated(pair, candidate); + return; + } + } + + auto srflx = packet->getAttribute(); + if (!srflx) { + WarnL << "Binding request missing XOR_MAPPED_ADDRESS attribute"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + } + + onCompleted(pair); +} + +void IceAgent::handleAllocateResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + + if (StunPacket::Class::SUCCESS_RESPONSE != packet->getClass()) { + WarnL << "fail, get response: " << packet->dumpString() << ", errorCode: " << (uint16_t)packet->getErrorCode(); + if (packet->getErrorCode() == StunAttrErrorCode::Code::AllocationQuotaReached) { + InfoL << "use stun retry"; + } + return; + } + + auto srflx = packet->getAttribute(); + if (!srflx) { + WarnL << "Binding request missing XOR_MAPPED_ADDRESS attribute"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + } + +#if 0 + CandidateInfo candidate; + candidate._type = CandidateInfo::AddressType::SRFLX; + candidate._addr._host = srflx->getAddrString(); + candidate._addr._port = srflx->getPort(); + candidate._base_addr._host = pair->get_local_ip(); + candidate._base_addr._port = pair->get_local_port(); + candidate._ufrag = getUfrag(); + candidate._pwd = getPassword(); + onGatheringCandidate(pair, candidate); +#endif + + auto relay = packet->getAttribute(); + if (!relay) { + WarnL << "Binding request missing XOR_RELAYED_ADDRESS attribute"; + sendErrorResponse(packet, pair, StunAttrErrorCode::Code::BadRequest); + } + + CandidateInfo candidate; + candidate._type = CandidateInfo::AddressType::RELAY; + candidate._addr._host = relay->getIp(); + candidate._addr._port = relay->getPort(); + candidate._base_addr._host = relay->getIp(); + candidate._base_addr._port = relay->getPort(); + candidate._ufrag = getUfrag(); + candidate._pwd = getPassword(); + + TraceL << "get local candidate type " << candidate.dumpString() + << ", by srflx addr " << srflx->getIp() << " : " << srflx->getPort() + << ", by host addr " << pair->get_local_ip() << " : " << pair->get_local_port(); + onGatheringCandidate(pair, candidate); +} + +void IceAgent::handleCreatePermissionResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, const sockaddr_storage& peer_addr) { + // TraceL; + + if (StunPacket::Class::SUCCESS_RESPONSE != packet->getClass()) { + WarnL << "CreatePermission failed, response: " << packet->dumpString(); + return; + } + + // TraceL << "CreatePermission successfully"; + + static uint16_t next_channel = 0x4000; // 有效范围是 0x4000-0x7FFF + uint16_t channel_number = next_channel++; + if (next_channel > 0x7FFF) { + next_channel = 0x4000; // 循环使用通道号 + } + + sendChannelBindRequest(pair, channel_number, peer_addr); +} + +void IceAgent::handleChannelBindResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, uint16_t channel_number, const sockaddr_storage& peer_addr) { + // TraceL; + + if (StunPacket::Class::SUCCESS_RESPONSE != packet->getClass()) { + WarnL << "ChannelBind failed, response: " << packet->dumpString(); + return; + } + + InfoL << "ChannelBind success, channel_number=" << channel_number + << ", peer_addr=" << addrToStr(peer_addr) + << ", pair: " << pair->dumpString(2); + + addChannelBind(channel_number, peer_addr); +} + +void IceAgent::handleDataIndication(const StunPacket::Ptr& packet, const Pair::Ptr& pair) { + // TraceL; + + // 检查必要的属性 + auto peer_addr = packet->getAttribute(); + auto data = packet->getAttribute(); + + if (!peer_addr || !data) { + WarnL << "Data indication missing required attributes"; + return; + } + + // 获取对端地址 + auto addr = peer_addr->getAddr(); + + // 检查是否有对应peer地址的权限 + if (!hasPermission(addr)) { + WarnL << "No permission exists for peer address"; + return; + } + + // 获取数据 + auto buffer = data->getData(); + + // 创建一个新的缓冲区 + auto recv_buffer = BufferRaw::create(buffer.size()); + recv_buffer->assign(buffer.data(), buffer.size()); + + DebugL << "Received Data indication from peer: " << addrToStr(addr) << ", size: " << buffer.size(); + + // 通知上层收到数据 + pair->_relayed_addr = std::make_shared(); + memcpy(pair->_relayed_addr.get(), &addr, sizeof(addr)); + _listener->onIceTransportRecvData(recv_buffer, pair); +} + +void IceAgent::handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) { + // TraceL << "Received ChannelData message, channel number: " << channel_number; + + // 查找该通道号对应的目标地址 + auto it = _channel_bindings.find(channel_number); + if (it == _channel_bindings.end()) { + WarnL << "No binding found for channel number: " << channel_number; + return; + } + + // 获取目标地址 + sockaddr_storage addr = it->second; + + // 创建一个新的缓冲区用于转发 + auto buffer = BufferRaw::create(len); + buffer->assign(data, len); + + auto channel_pair = std::make_shared(*pair); + channel_pair->_relayed_addr = std::make_shared(); + memcpy(channel_pair->_relayed_addr.get(), &addr, sizeof(addr)); + _listener->onIceTransportRecvData(buffer, channel_pair); +} + +void IceAgent::onGatheringCandidate(const Pair::Ptr& pair, CandidateInfo& candidate) { + candidate._priority = calIceCandidatePriority(candidate._type); + InfoL << "got candidate " << candidate.dumpString(); + + // 使用_socket_candidate_manager替代_local_candidates进行5元组重复检查 + if (!_socket_candidate_manager.addMapping(pair->_socket, candidate)) { + InfoL << "has same 5 tuple, skip"; + return; + } + + _listener->onIceTransportGatheringCandidate(pair, candidate); + + //如果是REALY,当前的所有PEER Candidate进行CreatePermission + if (candidate._type == CandidateInfo::AddressType::RELAY) { + _socket_candidate_manager._has_relayed_candidate = true; + for (auto remote_candidate : _remote_candidates) { + localRelayedConnectivityCheck(remote_candidate); + } + } +} + +void IceAgent::onConnected(const IceTransport::Pair::Ptr& pair) { + DebugL << "get connectivity check pair: " << pair->dumpString(2); + + if (getState() != State::Running) { + InfoL << "ice state: "<< stateToString(getState()) << " is not running, skip"; + return; + } + + TraceL << "checklist size: " << _check_list.size(); + //判断ConnectivityCheck的pair是否在checklist中,存在的话加入到validlist + for (auto &candidate_pair : _check_list) { + auto &pair_it = candidate_pair->_local_pair; + auto &remote_candidate = candidate_pair->_remote_candidate; + auto &state = candidate_pair->_state; + + TraceL << "check pair " << candidate_pair->dumpString() << ", pair info: " << pair_it->dumpString(2); + + //即使是新的Peer 反射地址,也已经添加到_checklist中了 + //所以肯定会在_checklist中找到匹配项 + if (!Pair::is_same(pair_it.get(), pair.get())) { + continue; + } + + if (state == CandidateInfo::State::Frozen || state == CandidateInfo::State::Waiting) { + continue; + } + + state = CandidateInfo::State::Succeeded; + + // 检查ICE传输策略 + if (!checkIceTransportPolicy(*candidate_pair, pair)) { + return; + } + + InfoL << "push " << candidate_pair->dumpString() << " to valid_list"; + + // 直接将候选者对添加到valid_list + _valid_list.push_back(candidate_pair); + + if (getRole() == Role::Controlling) { + if (getState() != IceAgent::State::Nominated && getState() != IceAgent::State::Completed) { + //TODO:need process priority + setState(IceAgent::State::Nominated); + nominated(pair, remote_candidate); + } + } + } + + if (getRole() == Role::Controlled && _nominated_pair) { + onCompleted(_nominated_pair); + } +} + +void IceAgent::onCompleted(const IceTransport::Pair::Ptr& pair) { + // TraceL; + bool found_in_valid_list = false; + if (getImplementation() == Implementation::Full) { + for (auto &candidate_pair : _valid_list) { + auto &pair_it = candidate_pair->_local_pair; + if (Pair::is_same(pair_it.get(), pair.get())) { + candidate_pair->_nominated = true; + _select_candidate_pair = candidate_pair; + InfoL << "select pair: " << candidate_pair->dumpString(); + found_in_valid_list = true; + break; + } + } + + if (!found_in_valid_list) { + InfoL << "not found peer pair: ip: " << pair->get_peer_ip() << ", port: " << pair->get_peer_port() << "in valid_list, record first"; + //提名的candidate 未在_valid_list 中找到.先记录 + _nominated_pair = pair; + } + } else { + //Lite 模式,不做candidate校验逻辑 + found_in_valid_list = true; + } + + if (found_in_valid_list) { + + if (setSelectedPair(pair)) { + + if (getState() != IceAgent::State::Completed) { + setState(IceAgent::State::Completed); + } + + _listener->onIceTransportCompleted(); + _nominated_pair = nullptr; + } + + if (_nominated_response) { + sendPacket(_nominated_response, pair); + _nominated_response = nullptr; + } + } +} + +void IceAgent::refreshPermissions() { + if (!_ice_server || _ice_server->_schema != IceServerInfo::SchemaType::TURN) { + return; + } + + uint64_t now = toolkit::getCurrentMillisecond(); + + // 遍历所有权限,删除过期的权限 + for (auto it = _permissions.begin(); it != _permissions.end();) { + if (now - it->second > 5 * 60 * 1000) { + it = _permissions.erase(it); + } else { + ++it; + } + } + + // 对于 + for (auto& permission : _permissions) { + if (now - permission.second > 4 * 60 * 1000) { + // 创建一个新的权限请求 + sockaddr_storage addr = permission.first; + for (auto& socket : _socket_candidate_manager._relay_sockets) { + auto pair = std::make_shared(socket); + sendCreatePermissionRequest(pair, addr); + break; // 只需要使用一个本地候选项发送请求 + } + } + } +} + +void IceAgent::refreshChannelBindings() { + if (!_ice_server || _ice_server->_schema != IceServerInfo::SchemaType::TURN) { + return; + } + uint64_t now = toolkit::getCurrentMillisecond(); + + // 遍历所有通道绑定,删除过期的绑定 + for (auto it = _channel_binding_times.begin(); it != _channel_binding_times.end();) { + if (now - it->second > 10 * 60 * 1000) { // 通道绑定有效期为10分钟 + _channel_bindings.erase(it->first); + it = _channel_binding_times.erase(it); + } else { + ++it; + } + } + + // 对于即将过期的通道绑定(例如还有2分钟过期),刷新它们 + for (auto& binding_time : _channel_binding_times) { + if (now - binding_time.second > 8 * 60 * 1000) { + uint16_t channel_number = binding_time.first; + auto it = _channel_bindings.find(channel_number); + if (it != _channel_bindings.end()) { + sockaddr_storage addr = it->second; + for (auto& socket : _socket_candidate_manager._relay_sockets) { + auto pair = std::make_shared(socket); + sendChannelBindRequest(pair, channel_number, addr); + break; // 只需要使用一个本地候选项发送请求 + } + } + } + } +} + +bool IceAgent::setSelectedPair(const Pair::Ptr& pair) { + if (_selected_pair && Pair::is_same(pair.get(), _selected_pair.get())){ + return false; + } + + if (_selected_pair) { + InfoL << "Previous selected_pair: " << _selected_pair->dumpString(2); + InfoL << "New selected_pair: " << pair->dumpString(2); + } else { + InfoL << "Initial selected_pair: " << pair->dumpString(2); + } + + _last_selected_pair = std::static_pointer_cast(_selected_pair); + _selected_pair = pair; + return true; +} + +void IceAgent::removePair(const toolkit::SocketHelper *socket) { + // TODO +} + +std::vector IceAgent::getPairs() const { + // TODO + if (_selected_pair) { + return { _selected_pair }; + } + return {}; +} + +void IceAgent::sendSocketData(const Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush) { + auto use_pair = pair? pair : getSelectedPair(); + + if (use_pair == nullptr) { + WarnL << "pair should not be nullptr"; + return; + } + + if (use_pair->_relayed_addr) { + return sendRelayPacket(buf, use_pair, flush); + } + return sendSocketData_l(buf, use_pair, flush); +} + +void IceAgent::sendRelayPacket(const Buffer::Ptr &buffer, const Pair::Ptr &pair, bool flush) { + // TraceL; + auto forward_pair = std::make_shared(*pair); + auto peer_addr = std::move(forward_pair->_relayed_addr); + forward_pair->_relayed_addr = nullptr; + + if (!hasPermission(*peer_addr)) { + WarnL << "No permission exists for peer: " << addrToStr(*peer_addr); + return; + } + + uint16_t channel_number; + if (hasChannelBind(*peer_addr, channel_number)) { + sendChannelData(channel_number, buffer, forward_pair); + } else { + sendSendIndication(*peer_addr, buffer, forward_pair); + } +} + +CandidateInfo IceAgent::getLocalCandidateInfo(const Pair::Ptr& pair) { + // 从socket_candidate_manager中查找对应的本地候选者信息 + for (const auto& socket_candidates : _socket_candidate_manager.socket_to_candidates) { + if (socket_candidates.first == pair->_socket) { + // 找到对应socket的候选者列表,选择第一个(host类型)作为默认 + if (!socket_candidates.second.empty()) { + return socket_candidates.second[0]; + } + } + } + + throw std::invalid_argument("No candidate found for the specified socket pair"); +} + +void IceAgent::addToChecklist(const Pair::Ptr& pair, CandidateInfo& remote_candidate) { + try { + //TODO: 优化checklist + CandidateInfo local_candidate = getLocalCandidateInfo(pair); + auto candidate_pair = std::make_shared(std::make_shared(*pair), remote_candidate, local_candidate); + candidate_pair->_state = CandidateInfo::State::InProgress; + _check_list.push_back(candidate_pair); + + std::sort(_check_list.begin(), _check_list.end(), [] ( + const std::shared_ptr& a, const std::shared_ptr& b) { + return *a < *b; + }); + + InfoL << "connectivity check candidate pair " << candidate_pair->dumpString() << ", pair info: " << pair->dumpString(2); + + connectivityCheck(std::make_shared(*pair), remote_candidate); + } catch (std::exception &ex) { + WarnL << ex.what(); + } +} + +void IceTransport::checkRequestTimeouts() { + uint64_t now = toolkit::getCurrentMillisecond(); + GET_CONFIG(uint32_t, max_retry, kMaxStunRetry); + for (auto it = _response_handlers.begin(); it != _response_handlers.end();) { + auto& transaction_id = it->first; + auto& req_info = it->second; + + // 检查是否超时 + if (now >= req_info._next_timeout) { + if (req_info._retry_count >= max_retry) { + // 超过最大重传次数,放弃请求并清理 + WarnL << "STUN request timeout after " << max_retry + << " retries, transaction_id: " << hexdump(transaction_id.data(), transaction_id.size()); + it = _response_handlers.erase(it); + continue; + } else { + // 执行重传 + retransmitRequest(transaction_id, req_info); + } + } + ++it; + } +} + +void IceTransport::retransmitRequest(const std::string& transaction_id, RequestInfo& req_info) { + // 增加重传次数 + req_info._retry_count++; + + // RTO翻倍(指数退避) + req_info._rto *= 2; + + // 计算下次超时时间 + uint64_t now = toolkit::getCurrentMillisecond(); + req_info._next_timeout = now + req_info._rto; + +#if 0 + TraceL << "Retransmitting STUN request (attempt " << req_info._retry_count + << "/" << RequestInfo::MAX_RETRIES << "), RTO: " << req_info._rto + << "ms, transaction_id: " << hexdump(transaction_id.data(), transaction_id.size()); +#endif + + // 重新发送请求包 + sendPacket(req_info._request, req_info._pair); +} + +Json::Value IceAgent::getChecklistInfo() const { + Json::Value result; + + Json::Value local_candidates_array(Json::arrayValue); + auto all_local_candidates = _socket_candidate_manager.getAllCandidates(); + for (const auto& local_candidate : all_local_candidates) { + Json::Value candidate_info; + candidate_info["type"] = CandidateInfo::getAddressTypeStr(local_candidate._type); + candidate_info["host"] = local_candidate._addr._host; + candidate_info["port"] = local_candidate._addr._port; + candidate_info["priority"] = local_candidate._priority; + if (!local_candidate._base_addr._host.empty()) { + candidate_info["base_host"] = local_candidate._base_addr._host; + candidate_info["base_port"] = local_candidate._base_addr._port; + } + local_candidates_array.append(candidate_info); + } + result["local_candidates"] = local_candidates_array; + result["local_candidates_count"] = static_cast(all_local_candidates.size()); + + Json::Value remote_candidates_array(Json::arrayValue); + for (const auto& remote_candidate : _remote_candidates) { + Json::Value candidate_info; + candidate_info["type"] = CandidateInfo::getAddressTypeStr(remote_candidate._type); + candidate_info["host"] = remote_candidate._addr._host; + candidate_info["port"] = remote_candidate._addr._port; + candidate_info["priority"] = remote_candidate._priority; + if (!remote_candidate._base_addr._host.empty()) { + candidate_info["base_host"] = remote_candidate._base_addr._host; + candidate_info["base_port"] = remote_candidate._base_addr._port; + } + remote_candidates_array.append(candidate_info); + } + result["remote_candidates"] = remote_candidates_array; + result["remote_candidates_count"] = static_cast(_remote_candidates.size()); + + Json::Value checklist_array(Json::arrayValue); + for (const auto& candidate_pair : _check_list) { + Json::Value entry; + entry["candidate_pair"] = candidate_pair->_local_candidate.dumpString() + " <-> " + candidate_pair->_remote_candidate.dumpString(); + entry["state"] = CandidateInfo::getStateStr(candidate_pair->_state); + entry["priority"] = (Json::UInt64)candidate_pair->_priority; + entry["nominated"] = candidate_pair->_nominated; + checklist_array.append(entry); + } + + result["checklists"] = checklist_array; + result["checklists_count"] = (int)_check_list.size(); + result["ice_state"] = stateToString(_state); + + if (_selected_pair) { + Json::Value selected_pair; + selected_pair["local_addr"] = _selected_pair->get_local_ip() + ":" + std::to_string(_selected_pair->get_local_port()); + selected_pair["remote_addr"] = _selected_pair->get_peer_ip() + ":" + std::to_string(_selected_pair->get_peer_port()); + if (!_selected_pair->get_relayed_ip().empty()) { + selected_pair["relayed_addr"] = _selected_pair->get_relayed_ip() + ":" + std::to_string(_selected_pair->get_relayed_port()); + } + + if (_select_candidate_pair) { + selected_pair["candidate_pair"] = _select_candidate_pair->_local_candidate.dumpString() + " <-> " + _select_candidate_pair->_remote_candidate.dumpString(); + } + + result["selected_pair"] = selected_pair; + } else { + result["selected_pair"] = Json::nullValue; + } + return result; +} + +size_t IceAgent::getRecvSpeed() { + size_t ret = 0; + for (auto s : _socket_candidate_manager.getAllSockets()) { + if (s && s->getSock()) { + ret += s->getSock()->getRecvSpeed(); + } + } + return ret; +} + +size_t IceAgent::getRecvTotalBytes() { + size_t ret = 0; + for (auto s : _socket_candidate_manager.getAllSockets()) { + if (s && s->getSock()) { + ret += s->getSock()->getRecvTotalBytes(); + } + } + return ret; +} + +size_t IceAgent::getSendSpeed() { + size_t ret = 0; + for (auto s : _socket_candidate_manager.getAllSockets()) { + if (s && s->getSock()) { + ret += s->getSock()->getSendSpeed(); + } + } + return ret; +} + +size_t IceAgent::getSendTotalBytes() { + size_t ret = 0; + for (auto s : _socket_candidate_manager.getAllSockets()) { + if (s && s->getSock()) { + ret += s->getSock()->getSendTotalBytes(); + } + } + return ret; +} +} // namespace RTC diff --git a/webrtc/IceTransport.hpp b/webrtc/IceTransport.hpp new file mode 100644 index 00000000..0d099990 --- /dev/null +++ b/webrtc/IceTransport.hpp @@ -0,0 +1,758 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. +*/ + +#ifndef ZLMEDIAKIT_WEBRTC_ICE_TRANSPORT_HPP +#define ZLMEDIAKIT_WEBRTC_ICE_TRANSPORT_HPP + +#include +#include +#include +#include +#include +#include +#include +#include "json/json.h" +#include "Util/Byte.hpp" +#include "Poller/Timer.h" +#include "Poller/EventPoller.h" +#include "Network/Socket.h" +#include "Network/UdpClient.h" +#include "Network/Session.h" +#include "logger.h" +#include "StunPacket.hpp" + +namespace RTC { + +uint64_t calCandidatePairPriority(uint32_t G, uint32_t D); + +class CandidateAddr { +public: + + bool operator==(const CandidateAddr& rhs) const { + return ((_host == rhs._host) && (_port == rhs._port)); + } + + bool operator!=(const CandidateAddr& rhs) const { + return !(*this == rhs); + } + + std::string dumpString() const { + return _host + ":" + std::to_string(_port); + } + +public: + std::string _host; + uint16_t _port = 0; +}; + + +class CandidateTuple { +public: + using Ptr = std::shared_ptr; + CandidateTuple() = default; + virtual ~CandidateTuple() = default; + + enum class AddressType { + HOST = 1, + SRFLX, //server reflexive + PRFLX, //peer reflexive + RELAY, + }; + + enum class SecureType { + NOT_SECURE = 1, + SECURE, + }; + + enum class TransportType { + UDP = 1, + TCP, + }; + + bool operator<(const CandidateTuple& rhs) const { + return (_priority < rhs._priority); + } + + bool operator==(const CandidateTuple& rhs) const { + return ((_addr == rhs._addr) + && (_priority == rhs._priority) + && (_transport == rhs._transport) && (_secure == rhs._secure)); + } + + struct ClassHash { + std::size_t operator()(const CandidateTuple& t) const { + std::string str = t._addr._host + std::to_string(t._addr._port) + + std::to_string((uint32_t)t._transport) + std::to_string((uint32_t)t._secure); + return std::hash()(str); + } + }; + + struct ClassEqual { + bool operator()(const CandidateTuple& a, const CandidateTuple& b) const { + return a == b; + } + }; + +public: + CandidateAddr _addr; + uint32_t _priority = 0; + TransportType _transport = TransportType::UDP; + SecureType _secure = SecureType::NOT_SECURE; + std::string _ufrag; + std::string _pwd; +}; + +class CandidateInfo : public CandidateTuple { +public: + using Ptr = std::shared_ptr; + CandidateInfo() = default; + virtual ~CandidateInfo() = default; + + enum class AddressType { + INVALID = 0, + HOST = 1, + SRFLX, // server reflx + PRFLX, // peer reflx + RELAY, + }; + + enum class State { + Frozen = 1, //尚未check,并还不需要check + Waiting, //尚未发送check,但也不是Frozen + InProgress, //已经发起check,但是仍在进行中 + Succeeded, //check success + Failed, //check failed + }; + + bool operator==(const CandidateInfo& rhs) const { + return CandidateTuple::operator==(rhs) && (_type == rhs._type); + } + + std::string getAddressTypeStr() const { + return getAddressTypeStr(_type); + } + + // 获取候选者地址类型字符串的静态函数 + static std::string getAddressTypeStr(CandidateInfo::AddressType type) { + switch (type) { + case CandidateInfo::AddressType::HOST: return "host"; + case CandidateInfo::AddressType::SRFLX: return "srflx"; + case CandidateInfo::AddressType::PRFLX: return "reflx"; + case CandidateInfo::AddressType::RELAY: return "relay"; + default: return "invalid"; + } + } + + static std::string getStateStr(State state) { + switch (state) { + case State::Frozen: return "frozen"; + case State::Waiting: return "waiting"; + case State::InProgress: return "in_progress"; + case State::Succeeded: return "succeeded"; + case State::Failed: return "failed"; + default: break; + } + return "unknown"; + } + + std::string dumpString() const { + return getAddressTypeStr() + " " + _addr.dumpString(); + } + +public: + AddressType _type = AddressType::HOST; + CandidateAddr _base_addr; +}; + +// ice stun/turn服务器配置 +// 格式为: (stun/turn)[s]:host:port[?transport=(tcp/udp)], 默认udp模式 +// 例如: +// stun:stun.l.google.com:19302 → 谷歌的 STUN 服务器(UDP)。 +// turn:turn.example.com:3478?transport=tcp → 使用 TCP 的 TURN 服务器。 +// turns:turn.example.com:5349 → 使用 TLS 的 TURN 服务器。 +class IceServerInfo : public CandidateTuple { +public: + using Ptr = std::shared_ptr; + IceServerInfo() = default; + virtual ~IceServerInfo() = default; + IceServerInfo(const std::string &url) { parse(url); } + void parse(const std::string &url); + + enum class SchemaType { + TURN = 1, + STUN, + }; + +public: + std::string _full_url; + std::string _param_strs; + SchemaType _schema = SchemaType::TURN; +}; + +class IceTransport : public std::enable_shared_from_this { +public: + using Ptr = std::shared_ptr; + + class Pair { + public: + using Ptr = std::shared_ptr; + + Pair() = default; + Pair(toolkit::SocketHelper::Ptr socket) : _socket(std::move(socket)) {} + Pair(toolkit::SocketHelper::Ptr socket, std::string peer_host, uint16_t peer_port, + std::shared_ptr relayed_addr = nullptr) : + _socket(std::move(socket)), _peer_host(std::move(peer_host)), _peer_port(peer_port), _relayed_addr(std::move(relayed_addr)) { + } + + Pair(Pair &that) { + _socket = that._socket; + _peer_host = that._peer_host; + _peer_port = that._peer_port; + _relayed_addr = nullptr; + if (that._relayed_addr) { + _relayed_addr = std::make_shared(); + memcpy(_relayed_addr.get(), that._relayed_addr.get(), sizeof(sockaddr_storage)); + } + } + virtual ~Pair() = default; + + void get_peer_addr(sockaddr_storage &peer_addr) const { + if (!_peer_host.empty()) { + peer_addr = toolkit::SockUtil::make_sockaddr(_peer_host.data(), _peer_port); + } else { + auto addr = _socket->get_peer_addr(); + if (addr->sa_family == AF_INET6 && IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)addr)->sin6_addr)) { + memset(&peer_addr, 0, sizeof(peer_addr)); + // 转换IPv6 v4mapped地址为IPv4地址 + struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)addr; + struct sockaddr_in *addr4 = (struct sockaddr_in *)&peer_addr; + addr4->sin_family = AF_INET; + addr4->sin_port = addr6->sin6_port; + memcpy(&addr4->sin_addr, &addr6->sin6_addr.s6_addr[12], 4); + } else { + memcpy(&peer_addr, addr, toolkit::SockUtil::get_sock_len(addr)); + } + } + } + + bool get_relayed_addr(sockaddr_storage &peerAddr) const { + if (!_relayed_addr) { + return false; + } + + memcpy(&peerAddr, _relayed_addr.get(), sizeof(peerAddr)); + return true; + } + + std::string get_local_ip() const { return _socket->get_local_ip(); } + + uint16_t get_local_port() const { return _socket->get_local_port(); } + + std::string get_peer_ip() const { return !_peer_host.empty() ? _peer_host : _socket->get_peer_ip(); } + + uint16_t get_peer_port() const { return !_peer_host.empty() ? _peer_port : _socket->get_peer_port(); } + + + std::string get_relayed_ip() const { return _relayed_addr ? toolkit::SockUtil::inet_ntoa((const struct sockaddr *)_relayed_addr.get()) : ""; } + + uint16_t get_relayed_port() const { return _relayed_addr ? toolkit::SockUtil::inet_port((const struct sockaddr *)_relayed_addr.get()) : 0; } + + static bool is_same_relayed_addr(Pair *a, Pair *b) { + if (a->_relayed_addr && b->_relayed_addr) { + return toolkit::SockUtil::is_same_addr( + reinterpret_cast(a->_relayed_addr.get()), reinterpret_cast(b->_relayed_addr.get())); + } + return (a->_relayed_addr == b->_relayed_addr); + } + + static bool is_same(Pair* a, Pair* b) { + // FIXME: a->_socket == b->_socket条件成立后,后面get_peer_ip和get_peer_port一定相同 + if ((a->_socket == b->_socket) + && (a->get_peer_ip() == b->get_peer_ip()) + && (a->get_peer_port() == b->get_peer_port()) + && (is_same_relayed_addr(a, b))) { + return true; + } + return false; + } + + std::string dumpString(uint8_t flag) const { + toolkit::_StrPrinter sp; + static const char* fStr[] = { "<-", "->", "<->" }; + sp << (_socket ? (_socket->getSock()->sockType() == toolkit::SockNum::Sock_TCP ? "tcp " : "udp ") : "") + << get_local_ip() << ":" << get_local_port() << fStr[flag] << get_peer_ip() << ":" << get_peer_port(); + if (_relayed_addr && flag == 2) { + sp << " relay " << get_relayed_ip() << ":" << get_relayed_port(); + } + return sp; + } + public: + toolkit::SocketHelper::Ptr _socket; + //对端host:port 地址,因为多个pair会复用一个socket对象,因此可能会和_socket的创建bind信息不一致 + std::string _peer_host; + uint16_t _peer_port; + + //中继后地址,用于实现TURN转发地址,当该地址不为空时,该地址为真正的peer地址,_peer_host和_peer_port表示中继地址 + std::shared_ptr _relayed_addr; + }; + + class Listener { + public: + virtual ~Listener() = default; + + public: + virtual void onIceTransportRecvData(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair) = 0; + virtual void onIceTransportGatheringCandidate(const Pair::Ptr&, const CandidateInfo&) = 0; + virtual void onIceTransportDisconnected() = 0; + virtual void onIceTransportCompleted() = 0; + }; + +public: + using MsgHandler = std::function; + + struct RequestInfo { + StunPacket::Ptr _request; // 原始请求包 + MsgHandler _handler; // 响应处理函数 + Pair::Ptr _pair; // 发送对 + uint64_t _send_time; // 首次发送时间(毫秒) + uint64_t _next_timeout; // 下次超时时间(毫秒) + uint32_t _retry_count; // 当前重传次数 + uint32_t _rto = 500; // 当前RTO值(毫秒) 初始RTO 500ms + + RequestInfo(StunPacket::Ptr req, MsgHandler h, Pair::Ptr p) + : _request(std::move(req)) + , _handler(std::move(h)) + , _pair(std::move(p)) + , _retry_count(0) { + _send_time = toolkit::getCurrentMillisecond(); + _next_timeout = _send_time + _rto; + } + }; + + IceTransport(Listener* listener, std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller); + virtual ~IceTransport() {} + + virtual void initialize(); + + const toolkit::EventPoller::Ptr& getPoller() const { return _poller; } + const std::string& getIdentifier() const { return _identifier; } + + const std::string& getUfrag() const { return _ufrag; } + const std::string& getPassword() const { return _password; } + void setUfrag(std::string ufrag) { _ufrag = std::move(ufrag); } + void setPassword(std::string password) { _password = std::move(password); } + + virtual bool processSocketData(const uint8_t* data, size_t len, const Pair::Ptr& pair); + virtual void sendSocketData(const toolkit::Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush = true); + void sendSocketData_l(const toolkit::Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush = true); + +protected: + virtual void processStunPacket(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + virtual void processRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + virtual void processResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + virtual bool processChannelData(const uint8_t* data, size_t len, const Pair::Ptr& pair); + virtual StunPacket::Authentication checkRequestAuthentication(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + StunPacket::Authentication checkResponseAuthentication(const StunPacket::Ptr& request, const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void processUnauthorizedResponse(const StunPacket::Ptr& response, const StunPacket::Ptr& request, const Pair::Ptr& pair, MsgHandler handler); + virtual void handleBindingRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + virtual void handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) {}; + + void sendChannelData(uint16_t channel_number, const toolkit::Buffer::Ptr &buffer, const Pair::Ptr& pair); + virtual void sendUnauthorizedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void sendErrorResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, StunAttrErrorCode::Code errorCode); + void sendRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair, MsgHandler handler); + void sendPacket(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + + // For permissions + bool hasPermission(const sockaddr_storage& addr); + void addPermission(const sockaddr_storage& addr); + + // For Channel Bind + bool hasChannelBind(uint16_t channel_number); + bool hasChannelBind(const sockaddr_storage& addr, uint16_t& channel_number); + void addChannelBind(uint16_t channel_number, const sockaddr_storage& addr); + + toolkit::SocketHelper::Ptr createSocket(CandidateTuple::TransportType type, const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port = 0); + toolkit::SocketHelper::Ptr createUdpSocket(const std::string &target_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port); + + void checkRequestTimeouts(); + void retransmitRequest(const std::string& transaction_id, RequestInfo& req_info); + +protected: + std::string _identifier; + toolkit::EventPoller::Ptr _poller; + Listener* _listener = nullptr; + std::unordered_map _response_handlers; + std::unordered_map, MsgHandler, StunPacket::ClassMethodHash> _request_handlers; + + // for local + std::string _ufrag; + std::string _password; + + // For permissions + std::unordered_map _permissions; + + // For Channel Bind + std::unordered_map _channel_bindings; + std::unordered_map _channel_binding_times; + + // For STUN request retry + std::shared_ptr _retry_timer; +}; + +class IceServer : public IceTransport { +public: + using Ptr = std::shared_ptr; + using WeakPtr = std::weak_ptr; + IceServer(Listener* listener, std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller); + virtual ~IceServer() {} + + bool processSocketData(const uint8_t* data, size_t len, const Pair::Ptr& pair) override; + void relayForwordingData(const toolkit::Buffer::Ptr& buffer, const sockaddr_storage& peer_addr); + void relayBackingData(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair, const sockaddr_storage& peer_addr); + +protected: + void processRelayPacket(const toolkit::Buffer::Ptr &buffer, const Pair::Ptr& pair); + void handleAllocateRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleRefreshRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleCreatePermissionRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleChannelBindRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleSendIndication(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) override; + + StunPacket::Authentication checkRequestAuthentication(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override; + + void sendDataIndication(const sockaddr_storage& peer_addr, const toolkit::Buffer::Ptr &buffer, const Pair::Ptr& pair); + void sendUnauthorizedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override; + + toolkit::SocketHelper::Ptr allocateRelayed(const Pair::Ptr& pair); + toolkit::SocketHelper::Ptr createRelayedUdpSocket(const std::string &peer_host, uint16_t peer_port, const std::string &local_ip, uint16_t local_port); + +protected: + std::vector _nonce_list; + + std::unordered_map /* port */, Pair::Ptr /*relayed_pairs*/>, + toolkit::SockUtil::SockAddrHash, toolkit::SockUtil::SockAddrEqual> _relayed_pairs; + Pair::Ptr _session_pair; +}; + +class IceAgent : public IceTransport { + +public: + using Ptr = std::shared_ptr; + + // 候选者对信息结构 + struct CandidatePair { + Pair::Ptr _local_pair; // 本地候选者对 + CandidateInfo _remote_candidate; // 远程候选者信息 + CandidateInfo _local_candidate; // 本地候选者信息 + uint64_t _priority; // 候选者对优先级(64位,符合RFC 8445) + CandidateInfo::State _state; // 连通性检查状态 + bool _nominated = false; + + CandidatePair(Pair::Ptr local_pair, CandidateInfo remote, CandidateInfo local) + : _local_pair(std::move(local_pair)) + , _remote_candidate(std::move(remote)) + , _local_candidate(std::move(local)) + , _state(CandidateInfo::State::Frozen) { + _priority = calCandidatePairPriority(local._priority, remote._priority); + } + std::string dumpString() const { + return "local " + _local_candidate.dumpString() + " <-> remote " + _remote_candidate.dumpString(); + } + // 比较操作符,用于优先级排序(高优先级在前) + bool operator<(const CandidatePair& other) const { + return _priority > other._priority; + } + }; + + enum class State { + //checklist state and ice session state + Running = 1, //正在进行候选地址的连通性检测 + Nominated, //发起提名,等待应答 + Completed, //所有候选地址完成验证,且至少有一路连接检测成功 + Failed, //所有候选地址检测失败,连接不可用 + }; + + static const char* stateToString(State state) { + switch (state) { + case State::Running: return "Running"; + case State::Completed: return "Completed"; + case State::Failed: return "Failed"; + default: return "Unknown"; + } + } + + enum class Role { + Controlling = 1, + Controlled, + }; + + enum class Implementation { + Lite = 1, + Full, + }; + + IceAgent(Listener* listener, Implementation implementation, Role role, + std::string ufrag, std::string password, toolkit::EventPoller::Ptr poller); + virtual ~IceAgent() {} + + void setIceServer(IceServerInfo::Ptr ice_server) { + _ice_server = std::move(ice_server); + } + + void gatheringCandidate(const CandidateTuple::Ptr& candidate_tuple, bool gathering_rflx, bool gathering_relay); + void connectivityCheck(CandidateInfo& candidate); + void nominated(const Pair::Ptr& pair, CandidateTuple& candidate); + + void sendSocketData(const toolkit::Buffer::Ptr& buf, const Pair::Ptr& pair, bool flush = true) override; + + IceAgent::Implementation getImplementation() const { + return _implementation; + } + + void setgetImplementation(IceAgent::Implementation implementation) { + InfoL << (uint32_t)implementation; + _implementation = implementation; + } + + IceAgent::Role getRole() const { + return _role; + } + + void setRole(IceAgent::Role role) { + InfoL << (uint32_t)role; + _role = role; + } + + IceAgent::State getState() const { + return _state; + } + + void setState(IceAgent::State state) { + InfoL << stateToString(state); + _state = state; + } + + Pair::Ptr getSelectedPair(bool try_last = false) const { + return try_last ? _last_selected_pair.lock() : _selected_pair; + } + bool setSelectedPair(const Pair::Ptr& pair); + + void removePair(const toolkit::SocketHelper *socket); + + std::vector getPairs() const; + + // 获取checklist信息,用于API查询 + Json::Value getChecklistInfo() const; + size_t getRecvSpeed(); + size_t getRecvTotalBytes(); + size_t getSendSpeed(); + size_t getSendTotalBytes(); + +protected: + void gatheringSrflxCandidate(const Pair::Ptr& pair); + void gatheringRelayCandidate(const Pair::Ptr& pair); + void localRelayedConnectivityCheck(CandidateInfo& candidate); + void connectivityCheck(const Pair::Ptr& pair, CandidateTuple& candidate); + void tryTriggerredCheck(const Pair::Ptr& pair); + + void sendBindRequest(const Pair::Ptr& pair, MsgHandler handler); + void sendBindRequest(const Pair::Ptr& pair, CandidateTuple& candidate, bool use_candidate, MsgHandler handler); + void sendAllocateRequest(const Pair::Ptr& pair); + void sendCreatePermissionRequest(const Pair::Ptr& pair, const sockaddr_storage& peer_addr); + void sendChannelBindRequest(const Pair::Ptr& pair, uint16_t channel_number, const sockaddr_storage& peer_addr); + + void processRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override; + + void handleBindingRequest(const StunPacket::Ptr& packet, const Pair::Ptr& pair) override; + void handleGatheringCandidateResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleConnectivityCheckResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, CandidateTuple& candidate); + void handleNominatedResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, CandidateTuple& candidate); + void handleAllocateResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleCreatePermissionResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, const sockaddr_storage& peer_addr); + void handleChannelBindResponse(const StunPacket::Ptr& packet, const Pair::Ptr& pair, uint16_t channel_number, const sockaddr_storage& peer_addr); + void handleDataIndication(const StunPacket::Ptr& packet, const Pair::Ptr& pair); + void handleChannelData(uint16_t channel_number, const char* data, size_t len, const Pair::Ptr& pair) override; + + void onGatheringCandidate(const Pair::Ptr& pair, CandidateInfo& candidate); + void onConnected(const Pair::Ptr& pair); + void onCompleted(const Pair::Ptr& pair); + + void refreshPermissions(); + void refreshChannelBindings(); + + void sendSendIndication(const sockaddr_storage& peer_addr, const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair); + void sendRelayPacket(const toolkit::Buffer::Ptr& buffer, const Pair::Ptr& pair, bool flush); + +private: + + CandidateInfo getLocalCandidateInfo(const Pair::Ptr& local_pair); + void addToChecklist(const Pair::Ptr& local_pair, CandidateInfo& remote_candidate); + +protected: + IceServerInfo::Ptr _ice_server; + + std::shared_ptr _refresh_timer; + + // for candidate + + Implementation _implementation = Implementation::Full; + Role _role = Role::Controlling; //ice role + uint64_t _tiebreaker = 0; // 8 bytes unsigned integer. + State _state = IceAgent::State::Running; //ice session state + + Pair::Ptr _selected_pair; + Pair::Ptr _nominated_pair; + StunPacket::Ptr _nominated_response; + std::weak_ptr _last_selected_pair; + + // 双向索引的候选地址管理结构 + struct SocketCandidateManager { + // socket -> candidates 的一对多映射 + std::unordered_map> socket_to_candidates; + + // candidate -> socket 的映射(用于快速查找) + std::unordered_map candidate_to_socket; + + // 按类型分组的socket列表,方便遍历 + std::vector _host_sockets; // HOST类型socket + std::vector _relay_sockets; // RELAY类型socket + + bool _has_relayed_candidate = false; + + // 添加映射关系,带5元组重复检查 + bool addMapping(toolkit::SocketHelper::Ptr socket, const CandidateInfo& candidate) { + // 检查5元组是否已存在 + if (candidate_to_socket.find(candidate) != candidate_to_socket.end()) { + return false; // 已存在相同的5元组 + } + + socket_to_candidates[socket].push_back(candidate); + candidate_to_socket[candidate] = socket; + + // 按类型分组 + if (candidate._type != CandidateInfo::AddressType::RELAY) { + addHostSocket(std::move(socket)); + } else if (candidate._type == CandidateInfo::AddressType::RELAY) { + addRelaySocket(std::move(socket)); + } + + return true; + } + + // 获取socket对应的所有candidates + std::vector getCandidates(const toolkit::SocketHelper::Ptr& socket) const { + auto it = socket_to_candidates.find(socket); + return (it != socket_to_candidates.end()) ? it->second : std::vector(); + } + + // 获取candidate对应的socket + toolkit::SocketHelper::Ptr getSocket(const CandidateInfo& candidate) const { + auto it = candidate_to_socket.find(candidate); + return (it != candidate_to_socket.end()) ? it->second : nullptr; + } + + // 获取所有socket(便于遍历) + std::vector getAllSockets() const { + std::vector result; + result.reserve(_host_sockets.size() + _relay_sockets.size()); + result.insert(result.end(), _host_sockets.begin(), _host_sockets.end()); + result.insert(result.end(), _relay_sockets.begin(), _relay_sockets.end()); + return result; + } + + // 获取所有candidates(便于遍历) + std::vector getAllCandidates() const { + std::vector result; + for (auto& pair : candidate_to_socket) { + result.push_back(pair.first); + } + return result; + } + + // 直接添加host socket + void addHostSocket(toolkit::SocketHelper::Ptr socket) { + if (std::find(_host_sockets.begin(), _host_sockets.end(), socket) == _host_sockets.end()) { + _host_sockets.emplace_back(std::move(socket)); + } + } + + // 直接添加relay socket + void addRelaySocket(toolkit::SocketHelper::Ptr socket) { + if (std::find(_relay_sockets.begin(), _relay_sockets.end(), socket) == _relay_sockets.end()) { + _relay_sockets.emplace_back(std::move(socket)); + } + } + + // 获取host sockets + const std::vector& getHostSockets() const { + return _host_sockets; + } + + // 获取relay sockets + const std::vector& getRelaySockets() const { + return _relay_sockets; + } + + // 移除host socket + void removeHostSocket(const toolkit::SocketHelper::Ptr& socket) { + auto it = std::find(_host_sockets.begin(), _host_sockets.end(), socket); + if (it != _host_sockets.end()) { + _host_sockets.erase(it); + } + } + + // 移除relay socket + void removeRelaySocket(const toolkit::SocketHelper::Ptr& socket) { + auto it = std::find(_relay_sockets.begin(), _relay_sockets.end(), socket); + if (it != _relay_sockets.end()) { + _relay_sockets.erase(it); + } + } + + // 清空host sockets + void clearHostSockets() { + _host_sockets.clear(); + } + + // 清空relay sockets + void clearRelaySockets() { + _relay_sockets.clear(); + } + + // 获取host socket数量 + size_t getHostSocketCount() const { + return _host_sockets.size(); + } + + // 获取relay socket数量 + size_t getRelaySocketCount() const { + return _relay_sockets.size(); + } + }; + + //for GATHERING_CANDIDATE + SocketCandidateManager _socket_candidate_manager; //local candidates + + //for CONNECTIVITY_CHECK + using CandidateSet = std::unordered_set; + CandidateSet _remote_candidates; + + //TODO:当前仅支持多数据流复用一个checklist + std::vector> _check_list; + std::vector> _valid_list; + std::shared_ptr _select_candidate_pair; + +}; + +} // namespace RTC +#endif //ZLMEDIAKIT_WEBRTC_ICE_TRANSPORT_HPP diff --git a/webrtc/Nack.cpp b/webrtc/Nack.cpp index 91231239..a376bad4 100644 --- a/webrtc/Nack.cpp +++ b/webrtc/Nack.cpp @@ -47,6 +47,7 @@ const string kNackIntervalRatio = RTC_FIELD "nackIntervalRatio"; // nack包中rtp个数,减小此值可以让nack包响应更灵敏 [AUTO-TRANSLATED:12393868] // Number of rtp in nack packet, reducing this value can make nack packet response more sensitive const string kNackRtpSize = RTC_FIELD "nackRtpSize"; +const string kNackAudioRtpSize = RTC_FIELD "nackAudioRtpSize"; static onceToken token([]() { mINI::Instance()[kMaxRtpCacheMS] = 5 * 1000; @@ -56,6 +57,7 @@ static onceToken token([]() { mINI::Instance()[kNackMaxCount] = 15; mINI::Instance()[kNackIntervalRatio] = 1.0f; mINI::Instance()[kNackRtpSize] = 8; + mINI::Instance()[kNackAudioRtpSize] = 4; }); } // namespace Rtc @@ -156,7 +158,8 @@ int64_t NackList::getNtpStamp(uint16_t seq) { //////////////////////////////////////////////////////////////////////////////////////////////// -NackContext::NackContext() { +NackContext::NackContext(TrackType type) { + _type = type; setOnNack(nullptr); } @@ -218,7 +221,9 @@ void NackContext::makeNack(uint16_t max_seq, bool flush) { // 最多生成5个nack包,防止seq大幅跳跃导致一直循环 [AUTO-TRANSLATED:9cc5da25] // Generate at most 5 nack packets to prevent seq from jumping significantly and causing continuous loops auto max_nack = 5u; - GET_CONFIG(uint32_t, nack_rtpsize, Rtc::kNackRtpSize); + GET_CONFIG(uint32_t, nack_video_rtpsize, Rtc::kNackRtpSize); + GET_CONFIG(uint32_t, nack_audio_rtpsize, Rtc::kNackAudioRtpSize); + auto nack_rtpsize = _type == TrackVideo ? nack_video_rtpsize : nack_audio_rtpsize; // kNackRtpSize must between 0 and 16 nack_rtpsize = std::min(nack_rtpsize, FCI_NACK::kBitSize); while (_nack_seq != max_seq && max_nack--) { diff --git a/webrtc/Nack.h b/webrtc/Nack.h index 1ebfda71..87fc4c59 100644 --- a/webrtc/Nack.h +++ b/webrtc/Nack.h @@ -1,90 +1,91 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_NACK_H -#define ZLMEDIAKIT_NACK_H - -#include -#include -#include -#include -#include "Rtsp/Rtsp.h" -#include "Rtcp/RtcpFCI.h" - -namespace mediakit { - -// RTC配置项目 [AUTO-TRANSLATED:19940011] -// RTC configuration project -namespace Rtc { -// ~ nack发送端,rtp接收端 [AUTO-TRANSLATED:bb169205] -// ~ nack sender, rtp receiver -// 最大保留的rtp丢包状态个数 [AUTO-TRANSLATED:70eee442] -// Maximum number of retained rtp packet loss states -extern const std::string kNackMaxSize; -// rtp丢包状态最长保留时间 [AUTO-TRANSLATED:f9306375] -// Maximum retention time for rtp packet loss states -extern const std::string kNackMaxMS; -} // namespace Rtc - -class NackList { -public: - void pushBack(RtpPacket::Ptr rtp); - void forEach(const FCI_NACK &nack, const std::function &cb); - -private: - void popFront(); - uint32_t getCacheMS(); - int64_t getNtpStamp(uint16_t seq); - RtpPacket::Ptr *getRtp(uint16_t seq); - -private: - uint32_t _cache_ms_check = 0; - std::deque _nack_cache_seq; - std::unordered_map _nack_cache_pkt; -}; - -class NackContext { -public: - using Ptr = std::shared_ptr; - using onNack = std::function; - - NackContext(); - - void received(uint16_t seq, bool is_rtx = false); - void setOnNack(onNack cb); - uint64_t reSendNack(); - -private: - void eraseFrontSeq(); - void doNack(const FCI_NACK &nack, bool record_nack); - void recordNack(const FCI_NACK &nack); - void clearNackStatus(uint16_t seq); - void makeNack(uint16_t max, bool flush = false); - -private: - bool _started = false; - int _rtt = 50; - onNack _cb; - std::set _seq; - // 最新nack包中的rtp seq值 [AUTO-TRANSLATED:6984d95a] - // RTP seq value in the latest nack packet - uint16_t _nack_seq = 0; - - struct NackStatus { - uint64_t first_stamp; - uint64_t update_stamp; - uint32_t nack_count = 0; - }; - std::map _nack_send_status; -}; - -} // namespace mediakit - -#endif //ZLMEDIAKIT_NACK_H +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_NACK_H +#define ZLMEDIAKIT_NACK_H + +#include +#include +#include +#include +#include "Rtsp/Rtsp.h" +#include "Rtcp/RtcpFCI.h" + +namespace mediakit { + +// RTC配置项目 [AUTO-TRANSLATED:19940011] +// RTC configuration project +namespace Rtc { +// ~ nack发送端,rtp接收端 [AUTO-TRANSLATED:bb169205] +// ~ nack sender, rtp receiver +// 最大保留的rtp丢包状态个数 [AUTO-TRANSLATED:70eee442] +// Maximum number of retained rtp packet loss states +extern const std::string kNackMaxSize; +// rtp丢包状态最长保留时间 [AUTO-TRANSLATED:f9306375] +// Maximum retention time for rtp packet loss states +extern const std::string kNackMaxMS; +} // namespace Rtc + +class NackList { +public: + void pushBack(RtpPacket::Ptr rtp); + void forEach(const FCI_NACK &nack, const std::function &cb); + +private: + void popFront(); + uint32_t getCacheMS(); + int64_t getNtpStamp(uint16_t seq); + RtpPacket::Ptr *getRtp(uint16_t seq); + +private: + uint32_t _cache_ms_check = 0; + std::deque _nack_cache_seq; + std::unordered_map _nack_cache_pkt; +}; + +class NackContext { +public: + using Ptr = std::shared_ptr; + using onNack = std::function; + + NackContext(TrackType type = TrackVideo); + + void received(uint16_t seq, bool is_rtx = false); + void setOnNack(onNack cb); + uint64_t reSendNack(); + +private: + void eraseFrontSeq(); + void doNack(const FCI_NACK &nack, bool record_nack); + void recordNack(const FCI_NACK &nack); + void clearNackStatus(uint16_t seq); + void makeNack(uint16_t max, bool flush = false); + +private: + bool _started = false; + int _rtt = 50; + TrackType _type; + onNack _cb; + std::set _seq; + // 最新nack包中的rtp seq值 [AUTO-TRANSLATED:6984d95a] + // RTP seq value in the latest nack packet + uint16_t _nack_seq = 0; + + struct NackStatus { + uint64_t first_stamp; + uint64_t update_stamp; + uint32_t nack_count = 0; + }; + std::map _nack_send_status; +}; + +} // namespace mediakit + +#endif //ZLMEDIAKIT_NACK_H diff --git a/webrtc/RtpExt.cpp b/webrtc/RtpExt.cpp index 935075bd..b9a405d6 100644 --- a/webrtc/RtpExt.cpp +++ b/webrtc/RtpExt.cpp @@ -1,659 +1,659 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#include "RtpExt.h" -#include "Sdp.h" - -#pragma pack(push, 1) - -using namespace std; -using namespace toolkit; - -namespace mediakit { - -//https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 -//https://tools.ietf.org/html/rfc5285 - -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | 0xBE | 0xDE | length=3 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | L=0 | data | ID | L=1 | data... -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// ...data | 0 (pad) | 0 (pad) | ID | L=3 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | data | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -class RtpExtOneByte { -public: - static constexpr uint16_t kMinSize = 1; - size_t getSize() const; - uint8_t getId() const; - void setId(uint8_t id); - uint8_t* getData(); - -private: -#if __BYTE_ORDER == __BIG_ENDIAN - uint8_t id: 4; - uint8_t len: 4; -#else - uint8_t len: 4; - uint8_t id: 4; -#endif - uint8_t data[1]; -}; - -//0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | 0x100 |appbits| length=3 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | L=0 | ID | L=1 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | data | 0 (pad) | ID | L=4 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | data | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -class RtpExtTwoByte { -public: - static constexpr uint16_t kMinSize = 2; - - size_t getSize() const; - uint8_t getId() const; - void setId(uint8_t id); - uint8_t* getData(); - -private: - uint8_t id; - uint8_t len; - uint8_t data[1]; -}; - -#pragma pack(pop) - -////////////////////////////////////////////////////////////////// - -size_t RtpExtOneByte::getSize() const { - return len + 1; -} - -uint8_t RtpExtOneByte::getId() const { - return id; -} - -void RtpExtOneByte::setId(uint8_t in) { - CHECK(in < (int)RtpExtType::reserved); - id = in & 0x0F; -} - -uint8_t *RtpExtOneByte::getData() { - return data; -} - -////////////////////////////////////////////////////////////////// - -size_t RtpExtTwoByte::getSize() const { - return len; -} - -uint8_t RtpExtTwoByte::getId() const { - return id; -} - -void RtpExtTwoByte::setId(uint8_t in) { - id = in; -} - -uint8_t *RtpExtTwoByte::getData() { - return data; -} - -////////////////////////////////////////////////////////////////// - -static constexpr uint16_t kOneByteHeader = 0xBEDE; -static constexpr uint16_t kTwoByteHeader = 0x1000; - -template -static bool isOneByteExt(){ - return false; -} - -template<> -bool isOneByteExt(){ - return true; -} - -template -void appendExt(map &ret, uint8_t *ptr, const uint8_t *end) { - while (ptr < end) { - auto ext = reinterpret_cast(ptr); - if (ext->getId() == (uint8_t) RtpExtType::padding) { - // padding,忽略 [AUTO-TRANSLATED:a7fda608] - // padding, ignore - ++ptr; - continue; - } - CHECK(reinterpret_cast(ext) + Type::kMinSize <= end); - CHECK(ext->getData() + ext->getSize() <= end); - ret.emplace(ext->getId(), RtpExt(ext, isOneByteExt(), reinterpret_cast(ext->getData()), ext->getSize())); - ptr += Type::kMinSize + ext->getSize(); - } -} - -RtpExt::RtpExt(void *ext, bool one_byte_ext, const char *str, size_t size) { - _ext = ext; - _one_byte_ext = one_byte_ext; - _data = str; - _size = size; -} - -const char *RtpExt::data() const { - return _data; -} - -size_t RtpExt::size() const { - return _size; -} - -const uint8_t& RtpExt::operator[](size_t pos) const{ - CHECK(pos < _size); - return ((uint8_t*)_data)[pos]; -} - -RtpExt::operator std::string() const{ - return string(_data, _size); -} - -map RtpExt::getExtValue(const RtpHeader *header) { - map ret; - assert(header); - auto ext_size = header->getExtSize(); - if (!ext_size) { - return ret; - } - auto reserved = header->getExtReserved(); - auto ptr = const_cast(header)->getExtData(); - auto end = ptr + ext_size; - if (reserved == kOneByteHeader) { - appendExt(ret, ptr, end); - return ret; - } - if ((reserved & 0xFFF0) == kTwoByteHeader) { - appendExt(ret, ptr, end); - return ret; - } - return ret; -} - -#define XX(type, url) {RtpExtType::type , url}, -static map s_type_to_url = {RTP_EXT_MAP(XX)}; -#undef XX - - -#define XX(type, url) {url, RtpExtType::type}, -static unordered_map s_url_to_type = {RTP_EXT_MAP(XX)}; -#undef XX - -RtpExtType RtpExt::getExtType(const string &url) { - auto it = s_url_to_type.find(url); - if (it == s_url_to_type.end()) { - WarnL << "unknown rtp ext url type: " << url; - return RtpExtType::padding; - } - return it->second; -} - -const string &RtpExt::getExtUrl(RtpExtType type) { - auto it = s_type_to_url.find(type); - if (it == s_type_to_url.end()) { - throw std::invalid_argument(string("未识别的rtp ext类型:") + to_string((int) type)); - } - return it->second; -} - -const char *RtpExt::getExtName(RtpExtType type) { -#define XX(type, url) case RtpExtType::type: return #type; - switch (type) { - RTP_EXT_MAP(XX) - default: return "unknown ext type"; - } -#undef XX -} - -string RtpExt::dumpString() const { - _StrPrinter printer; - switch (_type) { - case RtpExtType::ssrc_audio_level : { - bool vad; - printer << "audio level:" << (int) getAudioLevel(&vad) << ", vad:" << vad; - break; - } - case RtpExtType::abs_send_time : { - printer << "abs send time:" << getAbsSendTime(); - break; - } - case RtpExtType::transport_cc : { - printer << "twcc ext seq:" << getTransportCCSeq(); - break; - } - case RtpExtType::sdes_mid : { - printer << "sdes mid:" << getSdesMid(); - break; - } - case RtpExtType::sdes_rtp_stream_id : { - printer << "rtp stream id:" << getRtpStreamId(); - break; - } - case RtpExtType::sdes_repaired_rtp_stream_id : { - printer << "rtp repaired stream id:" << getRepairedRtpStreamId(); - break; - } - case RtpExtType::video_timing : { - uint8_t flags; - uint16_t encode_start, encode_finish, packetization_complete, last_pkt_left_pacer, reserved_net0, reserved_net1; - getVideoTiming(flags, encode_start, encode_finish, packetization_complete, last_pkt_left_pacer, - reserved_net0, reserved_net1); - printer << "video timing, flags:" << (int) flags - << ",encode:" << encode_start << "-" << encode_finish - << ",packetization_complete:" << packetization_complete - << ",last_pkt_left_pacer:" << last_pkt_left_pacer - << ",reserved_net0:" << reserved_net0 - << ",reserved_net1:" << reserved_net1; - break; - } - case RtpExtType::video_content_type : { - printer << "video content type:" << (int)getVideoContentType(); - break; - } - case RtpExtType::video_orientation : { - bool camera_bit, flip_bit, first_rotation, second_rotation; - getVideoOrientation(camera_bit, flip_bit, first_rotation, second_rotation); - printer << "video orientation:" << camera_bit << "-" << flip_bit << "-" << first_rotation << "-" << second_rotation; - break; - } - case RtpExtType::playout_delay : { - uint16_t min_delay, max_delay; - getPlayoutDelay(min_delay, max_delay); - printer << "playout delay:" << min_delay << "-" << max_delay; - break; - } - case RtpExtType::toffset : { - printer << "toffset:" << getTransmissionOffset(); - break; - } - case RtpExtType::framemarking : { - printer << "framemarking tid:" << (int)getFramemarkingTID(); - break; - } - default: { - printer << getExtName(_type) << ", hex:" << hexdump(data(), size()); - break; - } - } - return std::move(printer); -} - -//https://tools.ietf.org/html/rfc6464 -// 0 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | len=0 |V| level | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// -// Figure 1: Sample Audio Level Encoding Using the -// One-Byte Header Format -// -// -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | len=1 |V| level | 0 (pad) | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// -// Figure 2: Sample Audio Level Encoding Using the -// Two-Byte Header Format -uint8_t RtpExt::getAudioLevel(bool *vad) const{ - CHECK(_type == RtpExtType::ssrc_audio_level && size() >= 1); - auto &byte = (*this)[0]; - if (vad) { - *vad = byte & 0x80; - } - return byte & 0x7F; -} - -//http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time -// Wire format: 1-byte extension, 3 bytes of data. total 4 bytes extra per packet (plus shared 4 bytes for all extensions present: 2 byte magic word 0xBEDE, 2 byte # of extensions). Will in practice replace the “toffset” extension so we should see no long term increase in traffic as a result. [AUTO-TRANSLATED:178290be] -// Wire format: 1-byte extension, 3 bytes of data. total 4 bytes extra per packet (plus shared 4 bytes for all extensions present: 2 byte magic word 0xBEDE, 2 byte # of extensions). Will in practice replace the “toffset” extension so we should see no long term increase in traffic as a result. -// -//Encoding: Timestamp is in seconds, 24 bit 6.18 fixed point, yielding 64s wraparound and 3.8us resolution (one increment for each 477 bytes going out on a 1Gbps interface). -// -//Relation to NTP timestamps: abs_send_time_24 = (ntp_timestamp_64 >> 14) & 0x00ffffff ; NTP timestamp is 32 bits for whole seconds, 32 bits fraction of second. -// -//Notes: Packets are time stamped when going out, preferably close to metal. Intermediate RTP relays (entities possibly altering the stream) should remove the extension or set its own timestamp. -uint32_t RtpExt::getAbsSendTime() const { - CHECK(_type == RtpExtType::abs_send_time && size() >= 3); - uint32_t ret = 0; - ret |= (*this)[0] << 16; - ret |= (*this)[1] << 8; - ret |= (*this)[2]; - return ret; -} - -//https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | 0xBE | 0xDE | length=1 | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | L=1 |transport-wide sequence number | zero padding | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -uint16_t RtpExt::getTransportCCSeq() const { - CHECK(_type == RtpExtType::transport_cc && size() >= 2); - uint16_t ret; - ret = (*this)[0] << 8; - ret |= (*this)[1]; - return ret; -} - -//https://tools.ietf.org/html/draft-ietf-avtext-sdes-hdr-ext-07 -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | len | SDES Item text value ... | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -string RtpExt::getSdesMid() const { - CHECK(_type == RtpExtType::sdes_mid && size() >= 1); - return *this; -} - - -//https://tools.ietf.org/html/draft-ietf-avtext-rid-06 -// 用于simulcast [AUTO-TRANSLATED:59b2682f] -// Used for simulcast -//3.1. RTCP 'RtpStreamId' SDES Extension -// -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |RtpStreamId=TBD| length | RtpStreamId ... -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// -// -// The RtpStreamId payload is UTF-8 encoded and is not null-terminated. -// -// RFC EDITOR NOTE: Please replace TBD with the assigned SDES -// identifier value. - -//3.2. RTCP 'RepairedRtpStreamId' SDES Extension -// -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |Repaired...=TBD| length | RepairRtpStreamId ... -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// -// -// The RepairedRtpStreamId payload is UTF-8 encoded and is not null- -// terminated. -// -// RFC EDITOR NOTE: Please replace TBD with the assigned SDES -// identifier value. - -string RtpExt::getRtpStreamId() const { - CHECK(_type == RtpExtType::sdes_rtp_stream_id && size() >= 1); - return *this; -} - -string RtpExt::getRepairedRtpStreamId() const { - CHECK(_type == RtpExtType::sdes_repaired_rtp_stream_id && size() >= 1); - return *this; -} - - -//http://www.webrtc.org/experiments/rtp-hdrext/video-timing -//Wire format: 1-byte extension, 13 bytes of data. Total 14 bytes extra per packet (plus 1-3 padding byte in some cases, plus shared 4 bytes for all extensions present: 2 byte magic word 0xBEDE, 2 byte # of extensions). -// -//First byte is a flags field. Defined flags: -// -//0x01 - extension is set due to timer. -//0x02 - extension is set because the frame is larger than usual. -//Both flags may be set at the same time. All remaining 6 bits are reserved and should be ignored. -// -//Next, 6 timestamps are stored as 16-bit values in big-endian order, representing delta from the capture time of a packet in ms. Timestamps are, in order: -// -//Encode start. -//Encode finish. -//Packetization complete. -//Last packet left the pacer. -//Reserved for network. -//Reserved for network (2). - -void RtpExt::getVideoTiming(uint8_t &flags, - uint16_t &encode_start, - uint16_t &encode_finish, - uint16_t &packetization_complete, - uint16_t &last_pkt_left_pacer, - uint16_t &reserved_net0, - uint16_t &reserved_net1) const { - CHECK(_type == RtpExtType::video_timing && size() >= 13); - flags = (*this)[0]; - encode_start = (*this)[1] << 8 | (*this)[2]; - encode_finish = (*this)[3] << 8 | (*this)[4]; - packetization_complete = (*this)[5] << 8 | (*this)[6]; - last_pkt_left_pacer = (*this)[7] << 8 | (*this)[8]; - reserved_net0 = (*this)[9] << 8 | (*this)[10]; - reserved_net1 = (*this)[11] << 8 | (*this)[12]; -} - - -//http://www.webrtc.org/experiments/rtp-hdrext/color-space -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | L = 3 | primaries | transfer | matrix | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// |range+chr.sit. | -// +-+-+-+-+-+-+-+-+ - - -//http://www.webrtc.org/experiments/rtp-hdrext/video-content-type -//Values: -//0x00: Unspecified. Default value. Treated the same as an absence of an extension. -//0x01: Screenshare. Video stream is of a screenshare type. -// 0x02: 摄像头? [AUTO-TRANSLATED:ce2acbbb] -// 0x02: Camera? -//Notes: Extension shoud be present only in the last packet of key-frames. -// If attached to other packets it should be ignored. -// If extension is absent, Unspecified value is assumed. -uint8_t RtpExt::getVideoContentType() const { - CHECK(_type == RtpExtType::video_content_type && size() >= 1); - return (*this)[0]; -} - -//http://www.3gpp.org/ftp/Specs/html-info/26114.htm -void RtpExt::getVideoOrientation(bool &camera_bit, bool &flip_bit, bool &first_rotation, bool &second_rotation) const { - CHECK(_type == RtpExtType::video_orientation && size() >= 1); - uint8_t byte = (*this)[0]; - camera_bit = (byte & 0x08) >> 3; - flip_bit = (byte & 0x04) >> 2; - first_rotation = (byte & 0x02) >> 1; - second_rotation = byte & 0x01; -} - -//http://www.webrtc.org/experiments/rtp-hdrext/playout-delay -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -//+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -//| ID | len=2 | MIN delay | MAX delay | -//+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -void RtpExt::getPlayoutDelay(uint16_t &min_delay, uint16_t &max_delay) const { - CHECK(_type == RtpExtType::playout_delay && size() >= 3); - uint32_t bytes = (*this)[0] << 16 | (*this)[1] << 8 | (*this)[2]; - min_delay = (bytes & 0x00FFF000) >> 12; - max_delay = bytes & 0x00000FFF; -} - -//urn:ietf:params:rtp-hdrext:toffset -//https://tools.ietf.org/html/rfc5450 -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID | len=2 | transmission offset | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -uint32_t RtpExt::getTransmissionOffset() const { - CHECK(_type == RtpExtType::toffset && size() >= 3); - return (*this)[0] << 16 | (*this)[1] << 8 | (*this)[2]; -} - -//http://tools.ietf.org/html/draft-ietf-avtext-framemarking-07 -// 0 1 2 3 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | ID=? | L=2 |S|E|I|D|B| TID | LID | TL0PICIDX | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -uint8_t RtpExt::getFramemarkingTID() const { - CHECK(_type == RtpExtType::framemarking && size() >= 3); - return (*this)[0] & 0x07; -} - -void RtpExt::setExtId(uint8_t ext_id) { - assert(ext_id > (int) RtpExtType::padding && _ext); - if (_one_byte_ext) { - if (ext_id >= (int)RtpExtType::reserved) { - WarnL << "One byte rtp ext can not store id " << (int)ext_id << "(" << getExtName((RtpExtType)ext_id) << ") big than 14"; - clearExt(); - return; - } - auto ptr = reinterpret_cast(_ext); - ptr->setId(ext_id); - } else { - auto ptr = reinterpret_cast(_ext); - ptr->setId(ext_id); - } -} - -void RtpExt::clearExt(){ - assert(_ext); - if (_one_byte_ext) { - auto ptr = reinterpret_cast(_ext); - memset(ptr, (int) RtpExtType::padding, RtpExtOneByte::kMinSize + ptr->getSize()); - } else { - auto ptr = reinterpret_cast(_ext); - memset(ptr, (int) RtpExtType::padding, RtpExtTwoByte::kMinSize + ptr->getSize()); - } -} - -void RtpExt::setType(RtpExtType type) { - _type = type; -} - -RtpExtType RtpExt::getType() const { - return _type; -} - -RtpExt::operator bool() const { - return _ext != nullptr; -} - -RtpExtContext::RtpExtContext(const RtcMedia &m){ - for (auto &ext : m.extmap) { - auto ext_type = RtpExt::getExtType(ext.ext); - _rtp_ext_id_to_type.emplace(ext.id, ext_type); - _rtp_ext_type_to_id.emplace(ext_type, ext.id); - } -} - -string RtpExtContext::getRid(uint32_t ssrc) const{ - auto it = _ssrc_to_rid.find(ssrc); - if (it == _ssrc_to_rid.end()) { - return ""; - } - return it->second; -} - -void RtpExtContext::setRid(uint32_t ssrc, const string &rid) { - _ssrc_to_rid[ssrc] = rid; -} - -RtpExt RtpExtContext::changeRtpExtId(const RtpHeader *header, bool is_recv, string *rid_ptr, RtpExtType type) { - string rid, repaired_rid; - RtpExt ret; - auto ext_map = RtpExt::getExtValue(header); - for (auto &pr : ext_map) { - if (is_recv) { - auto it = _rtp_ext_id_to_type.find(pr.first); - if (it == _rtp_ext_id_to_type.end()) { - // TraceL << "接收rtp时,忽略不识别的rtp ext, id=" << (int) pr.first; [AUTO-TRANSLATED:284d8a38] - // TraceL << "Receiving rtp, ignoring unrecognized rtp ext, id=" << (int) pr.first; - pr.second.clearExt(); - continue; - } - pr.second.setType(it->second); - // 重新赋值ext id为 ext type,作为后面处理ext的统一中间类型 [AUTO-TRANSLATED:ab825878] - // Reassign ext id to ext type, as a unified intermediate type for processing ext later - pr.second.setExtId((uint8_t) it->second); - switch (it->second) { - case RtpExtType::sdes_rtp_stream_id : rid = pr.second.getRtpStreamId(); break; - case RtpExtType::sdes_repaired_rtp_stream_id : repaired_rid = pr.second.getRepairedRtpStreamId(); break; - default : break; - } - } else { - pr.second.setType((RtpExtType) pr.first); - auto it = _rtp_ext_type_to_id.find((RtpExtType) pr.first); - if (it == _rtp_ext_type_to_id.end()) { - // TraceL << "发送rtp时, 忽略不被客户端支持rtp ext:" << pr.second.dumpString(); [AUTO-TRANSLATED:5d9fd8cc] - // TraceL << "Sending rtp, ignoring rtp ext not supported by client:" << pr.second.dumpString(); - pr.second.clearExt(); - continue; - } - // 重新赋值ext id为客户端sdp声明的类型 [AUTO-TRANSLATED:06d60796] - // Reassign ext id to the type declared in client sdp - pr.second.setExtId(it->second); - } - if (pr.second.getType() == type) { - ret = pr.second; - } - } - - if (!is_recv) { - return ret; - } - if (rid.empty()) { - rid = repaired_rid; - } - auto ssrc = ntohl(header->ssrc); - if (rid.empty()) { - // 获取rid [AUTO-TRANSLATED:8ae4dffa] - // Get rid - rid = _ssrc_to_rid[ssrc]; - } else { - // 设置rid [AUTO-TRANSLATED:5e34819b] - // Set rid - auto it = _ssrc_to_rid.find(ssrc); - if (it == _ssrc_to_rid.end() || it->second != rid) { - _ssrc_to_rid[ssrc] = rid; - onGetRtp(header->pt, ssrc, rid); - } - } - if (rid_ptr) { - *rid_ptr = rid; - } - return ret; -} - -void RtpExtContext::setOnGetRtp(OnGetRtp cb) { - _cb = std::move(cb); -} - -void RtpExtContext::onGetRtp(uint8_t pt, uint32_t ssrc, const string &rid){ - if (_cb) { - _cb(pt, ssrc, rid); - } -} - -}// namespace mediakit \ No newline at end of file +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "RtpExt.h" +#include "Sdp.h" + +#pragma pack(push, 1) + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +//https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 +//https://tools.ietf.org/html/rfc5285 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | 0xBE | 0xDE | length=3 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | L=0 | data | ID | L=1 | data... +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// ...data | 0 (pad) | 0 (pad) | ID | L=3 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | data | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +class RtpExtOneByte { +public: + static constexpr uint16_t kMinSize = 1; + size_t getSize() const; + uint8_t getId() const; + void setId(uint8_t id); + uint8_t* getData(); + +private: +#if __BYTE_ORDER == __BIG_ENDIAN + uint8_t id: 4; + uint8_t len: 4; +#else + uint8_t len: 4; + uint8_t id: 4; +#endif + uint8_t data[1]; +}; + +//0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | 0x100 |appbits| length=3 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | L=0 | ID | L=1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | data | 0 (pad) | ID | L=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | data | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +class RtpExtTwoByte { +public: + static constexpr uint16_t kMinSize = 2; + + size_t getSize() const; + uint8_t getId() const; + void setId(uint8_t id); + uint8_t* getData(); + +private: + uint8_t id; + uint8_t len; + uint8_t data[1]; +}; + +#pragma pack(pop) + +////////////////////////////////////////////////////////////////// + +size_t RtpExtOneByte::getSize() const { + return len + 1; +} + +uint8_t RtpExtOneByte::getId() const { + return id; +} + +void RtpExtOneByte::setId(uint8_t in) { + CHECK(in < (int)RtpExtType::reserved); + id = in & 0x0F; +} + +uint8_t *RtpExtOneByte::getData() { + return data; +} + +////////////////////////////////////////////////////////////////// + +size_t RtpExtTwoByte::getSize() const { + return len; +} + +uint8_t RtpExtTwoByte::getId() const { + return id; +} + +void RtpExtTwoByte::setId(uint8_t in) { + id = in; +} + +uint8_t *RtpExtTwoByte::getData() { + return data; +} + +////////////////////////////////////////////////////////////////// + +static constexpr uint16_t kOneByteHeader = 0xBEDE; +static constexpr uint16_t kTwoByteHeader = 0x1000; + +template +static bool isOneByteExt(){ + return false; +} + +template<> +bool isOneByteExt(){ + return true; +} + +template +void appendExt(map &ret, uint8_t *ptr, const uint8_t *end) { + while (ptr < end) { + auto ext = reinterpret_cast(ptr); + if (ext->getId() == (uint8_t) RtpExtType::padding) { + // padding,忽略 [AUTO-TRANSLATED:a7fda608] + // padding, ignore + ++ptr; + continue; + } + CHECK(reinterpret_cast(ext) + Type::kMinSize <= end); + CHECK(ext->getData() + ext->getSize() <= end); + ret.emplace(ext->getId(), RtpExt(ext, isOneByteExt(), reinterpret_cast(ext->getData()), ext->getSize())); + ptr += Type::kMinSize + ext->getSize(); + } +} + +RtpExt::RtpExt(void *ext, bool one_byte_ext, const char *str, size_t size) { + _ext = ext; + _one_byte_ext = one_byte_ext; + _data = str; + _size = size; +} + +const char *RtpExt::data() const { + return _data; +} + +size_t RtpExt::size() const { + return _size; +} + +const uint8_t& RtpExt::operator[](size_t pos) const{ + CHECK(pos < _size); + return ((uint8_t*)_data)[pos]; +} + +RtpExt::operator std::string() const{ + return string(_data, _size); +} + +map RtpExt::getExtValue(const RtpHeader *header) { + map ret; + assert(header); + auto ext_size = header->getExtSize(); + if (!ext_size) { + return ret; + } + auto reserved = header->getExtReserved(); + auto ptr = const_cast(header)->getExtData(); + auto end = ptr + ext_size; + if (reserved == kOneByteHeader) { + appendExt(ret, ptr, end); + return ret; + } + if ((reserved & 0xFFF0) == kTwoByteHeader) { + appendExt(ret, ptr, end); + return ret; + } + return ret; +} + +#define XX(type, url) {RtpExtType::type , url}, +static map s_type_to_url = {RTP_EXT_MAP(XX)}; +#undef XX + + +#define XX(type, url) {url, RtpExtType::type}, +static unordered_map s_url_to_type = {RTP_EXT_MAP(XX)}; +#undef XX + +RtpExtType RtpExt::getExtType(const string &url) { + auto it = s_url_to_type.find(url); + if (it == s_url_to_type.end()) { + WarnL << "unknown rtp ext url type: " << url; + return RtpExtType::padding; + } + return it->second; +} + +const string &RtpExt::getExtUrl(RtpExtType type) { + auto it = s_type_to_url.find(type); + if (it == s_type_to_url.end()) { + throw std::invalid_argument(string("未识别的rtp ext类型:") + to_string((int) type)); + } + return it->second; +} + +const char *RtpExt::getExtName(RtpExtType type) { +#define XX(type, url) case RtpExtType::type: return #type; + switch (type) { + RTP_EXT_MAP(XX) + default: return "unknown ext type"; + } +#undef XX +} + +string RtpExt::dumpString() const { + _StrPrinter printer; + switch (_type) { + case RtpExtType::ssrc_audio_level : { + bool vad; + printer << "audio level:" << (int) getAudioLevel(&vad) << ", vad:" << vad; + break; + } + case RtpExtType::abs_send_time : { + printer << "abs send time:" << getAbsSendTime(); + break; + } + case RtpExtType::transport_cc : { + printer << "twcc ext seq:" << getTransportCCSeq(); + break; + } + case RtpExtType::sdes_mid : { + printer << "sdes mid:" << getSdesMid(); + break; + } + case RtpExtType::sdes_rtp_stream_id : { + printer << "rtp stream id:" << getRtpStreamId(); + break; + } + case RtpExtType::sdes_repaired_rtp_stream_id : { + printer << "rtp repaired stream id:" << getRepairedRtpStreamId(); + break; + } + case RtpExtType::video_timing : { + uint8_t flags; + uint16_t encode_start, encode_finish, packetization_complete, last_pkt_left_pacer, reserved_net0, reserved_net1; + getVideoTiming(flags, encode_start, encode_finish, packetization_complete, last_pkt_left_pacer, + reserved_net0, reserved_net1); + printer << "video timing, flags:" << (int) flags + << ",encode:" << encode_start << "-" << encode_finish + << ",packetization_complete:" << packetization_complete + << ",last_pkt_left_pacer:" << last_pkt_left_pacer + << ",reserved_net0:" << reserved_net0 + << ",reserved_net1:" << reserved_net1; + break; + } + case RtpExtType::video_content_type : { + printer << "video content type:" << (int)getVideoContentType(); + break; + } + case RtpExtType::video_orientation : { + bool camera_bit, flip_bit, first_rotation, second_rotation; + getVideoOrientation(camera_bit, flip_bit, first_rotation, second_rotation); + printer << "video orientation:" << camera_bit << "-" << flip_bit << "-" << first_rotation << "-" << second_rotation; + break; + } + case RtpExtType::playout_delay : { + uint16_t min_delay, max_delay; + getPlayoutDelay(min_delay, max_delay); + printer << "playout delay:" << min_delay << "-" << max_delay; + break; + } + case RtpExtType::toffset : { + printer << "toffset:" << getTransmissionOffset(); + break; + } + case RtpExtType::framemarking : { + printer << "framemarking tid:" << (int)getFramemarkingTID(); + break; + } + default: { + printer << getExtName(_type) << ", hex:" << hexdump(data(), size()); + break; + } + } + return printer; +} + +//https://tools.ietf.org/html/rfc6464 +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=0 |V| level | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Figure 1: Sample Audio Level Encoding Using the +// One-Byte Header Format +// +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=1 |V| level | 0 (pad) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// Figure 2: Sample Audio Level Encoding Using the +// Two-Byte Header Format +uint8_t RtpExt::getAudioLevel(bool *vad) const{ + CHECK(_type == RtpExtType::ssrc_audio_level && size() >= 1); + auto &byte = (*this)[0]; + if (vad) { + *vad = byte & 0x80; + } + return byte & 0x7F; +} + +//http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time +// Wire format: 1-byte extension, 3 bytes of data. total 4 bytes extra per packet (plus shared 4 bytes for all extensions present: 2 byte magic word 0xBEDE, 2 byte # of extensions). Will in practice replace the “toffset” extension so we should see no long term increase in traffic as a result. [AUTO-TRANSLATED:178290be] +// Wire format: 1-byte extension, 3 bytes of data. total 4 bytes extra per packet (plus shared 4 bytes for all extensions present: 2 byte magic word 0xBEDE, 2 byte # of extensions). Will in practice replace the “toffset” extension so we should see no long term increase in traffic as a result. +// +//Encoding: Timestamp is in seconds, 24 bit 6.18 fixed point, yielding 64s wraparound and 3.8us resolution (one increment for each 477 bytes going out on a 1Gbps interface). +// +//Relation to NTP timestamps: abs_send_time_24 = (ntp_timestamp_64 >> 14) & 0x00ffffff ; NTP timestamp is 32 bits for whole seconds, 32 bits fraction of second. +// +//Notes: Packets are time stamped when going out, preferably close to metal. Intermediate RTP relays (entities possibly altering the stream) should remove the extension or set its own timestamp. +uint32_t RtpExt::getAbsSendTime() const { + CHECK(_type == RtpExtType::abs_send_time && size() >= 3); + uint32_t ret = 0; + ret |= (*this)[0] << 16; + ret |= (*this)[1] << 8; + ret |= (*this)[2]; + return ret; +} + +//https://tools.ietf.org/html/draft-holmer-rmcat-transport-wide-cc-extensions-01 +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | 0xBE | 0xDE | length=1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | L=1 |transport-wide sequence number | zero padding | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +uint16_t RtpExt::getTransportCCSeq() const { + CHECK(_type == RtpExtType::transport_cc && size() >= 2); + uint16_t ret; + ret = (*this)[0] << 8; + ret |= (*this)[1]; + return ret; +} + +//https://tools.ietf.org/html/draft-ietf-avtext-sdes-hdr-ext-07 +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len | SDES Item text value ... | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +string RtpExt::getSdesMid() const { + CHECK(_type == RtpExtType::sdes_mid && size() >= 1); + return *this; +} + + +//https://tools.ietf.org/html/draft-ietf-avtext-rid-06 +// 用于simulcast [AUTO-TRANSLATED:59b2682f] +// Used for simulcast +//3.1. RTCP 'RtpStreamId' SDES Extension +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |RtpStreamId=TBD| length | RtpStreamId ... +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// +// The RtpStreamId payload is UTF-8 encoded and is not null-terminated. +// +// RFC EDITOR NOTE: Please replace TBD with the assigned SDES +// identifier value. + +//3.2. RTCP 'RepairedRtpStreamId' SDES Extension +// +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Repaired...=TBD| length | RepairRtpStreamId ... +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +// +// The RepairedRtpStreamId payload is UTF-8 encoded and is not null- +// terminated. +// +// RFC EDITOR NOTE: Please replace TBD with the assigned SDES +// identifier value. + +string RtpExt::getRtpStreamId() const { + CHECK(_type == RtpExtType::sdes_rtp_stream_id && size() >= 1); + return *this; +} + +string RtpExt::getRepairedRtpStreamId() const { + CHECK(_type == RtpExtType::sdes_repaired_rtp_stream_id && size() >= 1); + return *this; +} + + +//http://www.webrtc.org/experiments/rtp-hdrext/video-timing +//Wire format: 1-byte extension, 13 bytes of data. Total 14 bytes extra per packet (plus 1-3 padding byte in some cases, plus shared 4 bytes for all extensions present: 2 byte magic word 0xBEDE, 2 byte # of extensions). +// +//First byte is a flags field. Defined flags: +// +//0x01 - extension is set due to timer. +//0x02 - extension is set because the frame is larger than usual. +//Both flags may be set at the same time. All remaining 6 bits are reserved and should be ignored. +// +//Next, 6 timestamps are stored as 16-bit values in big-endian order, representing delta from the capture time of a packet in ms. Timestamps are, in order: +// +//Encode start. +//Encode finish. +//Packetization complete. +//Last packet left the pacer. +//Reserved for network. +//Reserved for network (2). + +void RtpExt::getVideoTiming(uint8_t &flags, + uint16_t &encode_start, + uint16_t &encode_finish, + uint16_t &packetization_complete, + uint16_t &last_pkt_left_pacer, + uint16_t &reserved_net0, + uint16_t &reserved_net1) const { + CHECK(_type == RtpExtType::video_timing && size() >= 13); + flags = (*this)[0]; + encode_start = (*this)[1] << 8 | (*this)[2]; + encode_finish = (*this)[3] << 8 | (*this)[4]; + packetization_complete = (*this)[5] << 8 | (*this)[6]; + last_pkt_left_pacer = (*this)[7] << 8 | (*this)[8]; + reserved_net0 = (*this)[9] << 8 | (*this)[10]; + reserved_net1 = (*this)[11] << 8 | (*this)[12]; +} + + +//http://www.webrtc.org/experiments/rtp-hdrext/color-space +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | L = 3 | primaries | transfer | matrix | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |range+chr.sit. | +// +-+-+-+-+-+-+-+-+ + + +//http://www.webrtc.org/experiments/rtp-hdrext/video-content-type +//Values: +//0x00: Unspecified. Default value. Treated the same as an absence of an extension. +//0x01: Screenshare. Video stream is of a screenshare type. +// 0x02: 摄像头? [AUTO-TRANSLATED:ce2acbbb] +// 0x02: Camera? +//Notes: Extension shoud be present only in the last packet of key-frames. +// If attached to other packets it should be ignored. +// If extension is absent, Unspecified value is assumed. +uint8_t RtpExt::getVideoContentType() const { + CHECK(_type == RtpExtType::video_content_type && size() >= 1); + return (*this)[0]; +} + +//http://www.3gpp.org/ftp/Specs/html-info/26114.htm +void RtpExt::getVideoOrientation(bool &camera_bit, bool &flip_bit, bool &first_rotation, bool &second_rotation) const { + CHECK(_type == RtpExtType::video_orientation && size() >= 1); + uint8_t byte = (*this)[0]; + camera_bit = (byte & 0x08) >> 3; + flip_bit = (byte & 0x04) >> 2; + first_rotation = (byte & 0x02) >> 1; + second_rotation = byte & 0x01; +} + +//http://www.webrtc.org/experiments/rtp-hdrext/playout-delay +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +//+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +//| ID | len=2 | MIN delay | MAX delay | +//+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +void RtpExt::getPlayoutDelay(uint16_t &min_delay, uint16_t &max_delay) const { + CHECK(_type == RtpExtType::playout_delay && size() >= 3); + uint32_t bytes = (*this)[0] << 16 | (*this)[1] << 8 | (*this)[2]; + min_delay = (bytes & 0x00FFF000) >> 12; + max_delay = bytes & 0x00000FFF; +} + +//urn:ietf:params:rtp-hdrext:toffset +//https://tools.ietf.org/html/rfc5450 +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID | len=2 | transmission offset | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +uint32_t RtpExt::getTransmissionOffset() const { + CHECK(_type == RtpExtType::toffset && size() >= 3); + return (*this)[0] << 16 | (*this)[1] << 8 | (*this)[2]; +} + +//http://tools.ietf.org/html/draft-ietf-avtext-framemarking-07 +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | ID=? | L=2 |S|E|I|D|B| TID | LID | TL0PICIDX | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +uint8_t RtpExt::getFramemarkingTID() const { + CHECK(_type == RtpExtType::framemarking && size() >= 3); + return (*this)[0] & 0x07; +} + +void RtpExt::setExtId(uint8_t ext_id) { + assert(ext_id > (int) RtpExtType::padding && _ext); + if (_one_byte_ext) { + if (ext_id >= (int)RtpExtType::reserved) { + WarnL << "One byte rtp ext can not store id " << (int)ext_id << "(" << getExtName((RtpExtType)ext_id) << ") big than 14"; + clearExt(); + return; + } + auto ptr = reinterpret_cast(_ext); + ptr->setId(ext_id); + } else { + auto ptr = reinterpret_cast(_ext); + ptr->setId(ext_id); + } +} + +void RtpExt::clearExt(){ + assert(_ext); + if (_one_byte_ext) { + auto ptr = reinterpret_cast(_ext); + memset(ptr, (int) RtpExtType::padding, RtpExtOneByte::kMinSize + ptr->getSize()); + } else { + auto ptr = reinterpret_cast(_ext); + memset(ptr, (int) RtpExtType::padding, RtpExtTwoByte::kMinSize + ptr->getSize()); + } +} + +void RtpExt::setType(RtpExtType type) { + _type = type; +} + +RtpExtType RtpExt::getType() const { + return _type; +} + +RtpExt::operator bool() const { + return _ext != nullptr; +} + +RtpExtContext::RtpExtContext(const RtcMedia &m){ + for (auto &ext : m.extmap) { + auto ext_type = RtpExt::getExtType(ext.ext); + _rtp_ext_id_to_type.emplace(ext.id, ext_type); + _rtp_ext_type_to_id.emplace(ext_type, ext.id); + } +} + +string RtpExtContext::getRid(uint32_t ssrc) const{ + auto it = _ssrc_to_rid.find(ssrc); + if (it == _ssrc_to_rid.end()) { + return ""; + } + return it->second; +} + +void RtpExtContext::setRid(uint32_t ssrc, const string &rid) { + _ssrc_to_rid[ssrc] = rid; +} + +RtpExt RtpExtContext::changeRtpExtId(const RtpHeader *header, bool is_recv, string *rid_ptr, RtpExtType type) { + string rid, repaired_rid; + RtpExt ret; + auto ext_map = RtpExt::getExtValue(header); + for (auto &pr : ext_map) { + if (is_recv) { + auto it = _rtp_ext_id_to_type.find(pr.first); + if (it == _rtp_ext_id_to_type.end()) { + // TraceL << "接收rtp时,忽略不识别的rtp ext, id=" << (int) pr.first; [AUTO-TRANSLATED:284d8a38] + // TraceL << "Receiving rtp, ignoring unrecognized rtp ext, id=" << (int) pr.first; + pr.second.clearExt(); + continue; + } + pr.second.setType(it->second); + // 重新赋值ext id为 ext type,作为后面处理ext的统一中间类型 [AUTO-TRANSLATED:ab825878] + // Reassign ext id to ext type, as a unified intermediate type for processing ext later + pr.second.setExtId((uint8_t) it->second); + switch (it->second) { + case RtpExtType::sdes_rtp_stream_id : rid = pr.second.getRtpStreamId(); break; + case RtpExtType::sdes_repaired_rtp_stream_id : repaired_rid = pr.second.getRepairedRtpStreamId(); break; + default : break; + } + } else { + pr.second.setType((RtpExtType) pr.first); + auto it = _rtp_ext_type_to_id.find((RtpExtType) pr.first); + if (it == _rtp_ext_type_to_id.end()) { + // TraceL << "发送rtp时, 忽略不被客户端支持rtp ext:" << pr.second.dumpString(); [AUTO-TRANSLATED:5d9fd8cc] + // TraceL << "Sending rtp, ignoring rtp ext not supported by client:" << pr.second.dumpString(); + pr.second.clearExt(); + continue; + } + // 重新赋值ext id为客户端sdp声明的类型 [AUTO-TRANSLATED:06d60796] + // Reassign ext id to the type declared in client sdp + pr.second.setExtId(it->second); + } + if (pr.second.getType() == type) { + ret = pr.second; + } + } + + if (!is_recv) { + return ret; + } + if (rid.empty()) { + rid = repaired_rid; + } + auto ssrc = ntohl(header->ssrc); + if (rid.empty()) { + // 获取rid [AUTO-TRANSLATED:8ae4dffa] + // Get rid + rid = _ssrc_to_rid[ssrc]; + } else { + // 设置rid [AUTO-TRANSLATED:5e34819b] + // Set rid + auto it = _ssrc_to_rid.find(ssrc); + if (it == _ssrc_to_rid.end() || it->second != rid) { + _ssrc_to_rid[ssrc] = rid; + onGetRtp(header->pt, ssrc, rid); + } + } + if (rid_ptr) { + *rid_ptr = rid; + } + return ret; +} + +void RtpExtContext::setOnGetRtp(OnGetRtp cb) { + _cb = std::move(cb); +} + +void RtpExtContext::onGetRtp(uint8_t pt, uint32_t ssrc, const string &rid){ + if (_cb) { + _cb(pt, ssrc, rid); + } +} + +}// namespace mediakit diff --git a/webrtc/RtpMap.h b/webrtc/RtpMap.h new file mode 100644 index 00000000..ea9ebbbf --- /dev/null +++ b/webrtc/RtpMap.h @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_RTPMAP_H +#define ZLMEDIAKIT_RTPMAP_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include "Extension/Frame.h" + +namespace mediakit { + +class RtpMap { +public: + using Ptr = std::shared_ptr; + RtpMap(std::string code_name, uint8_t payload, uint32_t clock_rate) + : _code_name(std::move(code_name)) + , _payload(payload) + , _clock_rate(clock_rate) {} + virtual ~RtpMap() = default; + + virtual TrackType getType() = 0; + + const std::map &getFmtp() const { return _fmtp; } + + const std::string &getCodeName() const { return _code_name; } + uint8_t getPayload() const { return _payload; } + uint32_t getClockRate() const { return _clock_rate; } + +protected: + std::map _fmtp; + std::string _code_name; + uint8_t _payload; + uint32_t _clock_rate; +}; + +class VideoRtpMap : public RtpMap { +public: + VideoRtpMap(std::string code_name, uint8_t payload, uint32_t clock_rate) + : RtpMap(std::move(code_name), payload, clock_rate) {}; + + TrackType getType() override { return TrackVideo; } +}; + +class AudioRtpMap : public RtpMap { +public: + AudioRtpMap( std::string code_name, uint8_t payload, uint32_t clock_rate) + : RtpMap(std::move(code_name), payload, clock_rate) {}; + + TrackType getType() override { return TrackAudio; }; +}; + +#define H264_PROFILE_IDC_MAP(XX) \ + XX(PROFILE_H264_BASELINE, 66, "baseline") \ + XX(PROFILE_H264_MAIN, 77, "main") \ + XX(PROFILE_H264_HIGH, 100, "high") \ + XX(PROFILE_H264_HIGH10, 110, "high10") \ + XX(PROFILE_H264_HIGH422, 122, "high422") \ + XX(PROFILE_H264_HIGH444, 244, "high444") \ + +typedef enum { + H264ProfileIdcInvalid = -1, +#define XX(name, value, str) name = value, + H264_PROFILE_IDC_MAP(XX) +#undef XX + H264ProfileIdcMax +} H264ProfileIdc; + +#define H264_PROFILE_LEVEL_MAP(XX) \ + XX(10) \ + XX(20) \ + XX(30) \ + XX(31) \ + XX(40) \ + XX(41) \ + XX(50) \ + XX(51) + +typedef enum { + H264ProfileLevelInvalid = -1, +#define XX(value) H264_PROFILE_LEVEL_##value = value, + H264_PROFILE_LEVEL_MAP(XX) +#undef XX + H264ProfileLevelMax +} H264ProfileLevel; + +class H264RtpMap : public VideoRtpMap { +public: + H264RtpMap(uint8_t payload, uint32_t clock_rate, H264ProfileIdc profile_idc) + : VideoRtpMap("H264", payload, clock_rate) + , _profile_idc(profile_idc) { + _fmtp.emplace("level-asymmetry-allowed", "1"); + _fmtp.emplace("packetization-mode", "1"); + + toolkit::_StrPrinter printer; + printer << std::setw(2) << std::setfill('0') << std::hex << _profile_idc; + printer << std::setw(2) << std::setfill('0') << std::hex << _profile_iop; + printer << std::setw(2) << std::setfill('0') << std::hex << _profile_level; + _fmtp.emplace("profile-level-id", printer); + }; + +private: + H264ProfileIdc _profile_idc; + int _profile_iop = 0; + H264ProfileLevel _profile_level = H264_PROFILE_LEVEL_31; +}; + +#define H265_PROFILE_IDC_MAP(XX) \ + XX(PROFILE_H265_MAIN, 1, "main") \ + XX(PROFILE_H265_MAIN10, 2, "main10") \ + XX(PROFILE_H265_MAINSTILL, 3, "mainstill") \ + XX(PROFILE_H265_RANGE_EXTS, 4, "RangeExtensions") \ + XX(PROFILE_H265_HIGH_THROUGHPUT, 5, "HighThroughput") \ + XX(PROFILE_H265_MULTIVIEW, 6, "MultiviewMain") \ + XX(PROFILE_H265_SCALABLE_MAIN, 7, "ScalableMain") \ + XX(PROFILE_H265_3DMAIN, 8, "3dMain") \ + XX(PROFILE_H265_SCREEN, 9, "ScreenContentCoding") \ + XX(PROFILE_H265_SCALABLE_RANGE_EXTENSIONS, 10, "ScalableRangeExtensions") \ + XX(PROFILE_H265_HIGH_SCREEN, 11, "HighThroughputScreenContentCoding") + +typedef enum { + H265ProfileIdcInvalid = -1, +#define XX(name, value, str) name = value, + H265_PROFILE_IDC_MAP(XX) +#undef XX + H265ProfileIdcMax +} H265ProfileIdc; + +#define H265_PROFILE_LEVEL_MAP(XX) \ + XX(30) \ + XX(60) \ + XX(63) \ + XX(90) \ + XX(93) \ + XX(120) \ + XX(123) \ + XX(150) \ + XX(153) \ + XX(156) \ + XX(180) \ + XX(183) \ + XX(186) + +typedef enum { + H265ProfileLevelInvalid = -1, +#define XX(value) H265_PROFILE_LEVEL_##value = value, + H265_PROFILE_LEVEL_MAP(XX) +#undef XX + H265ProfileLevelMax +} H265ProfileLevel; + +class H265RtpMap : public VideoRtpMap { +public: + H265RtpMap(uint8_t payload, uint32_t clock_rate, H265ProfileIdc profile_idc) + : VideoRtpMap("H265", payload, clock_rate) + , _profile_idc(profile_idc) { + _fmtp.emplace("level-asymmetry-allowed", "1"); + _fmtp.emplace("packetization-mode", "1"); + + _fmtp.emplace("profile-id", std::to_string(_profile_idc)); + _fmtp.emplace("tier-flag", std::to_string(_tier_flag)); + _fmtp.emplace("level-id", std::to_string(_profile_level)); + } + +private: + H265ProfileIdc _profile_idc; + int _tier_flag = 0; // 0: main tier; 1: high tier + H265ProfileLevel _profile_level = H265_PROFILE_LEVEL_30; +}; + +class VP9RtpMap : public VideoRtpMap { +public: + VP9RtpMap(uint8_t payload, uint32_t clock_rate, int profile_id) + : VideoRtpMap("VP9", payload, clock_rate) + , _profile_id(profile_id) { + _fmtp.emplace("profile-id", std::to_string(_profile_id)); + }; + +private: + int _profile_id = 1; // 0-3 +}; + +class AV1RtpMap : public VideoRtpMap { +public: + AV1RtpMap(uint8_t payload, uint32_t clock_rate, int profile_id) + : VideoRtpMap("AV1", payload, clock_rate) + , _profile_id(profile_id) { + // a=fmtp:45 level-idx=5;profile=0;tier=0 + _fmtp.emplace("profile-id", std::to_string(_profile_id)); + }; + +private: + int _profile_id = 0; // 0-2 +}; +} // namespace mediakit + +#endif // ZLMEDIAKIT_RTPMAP_H diff --git a/webrtc/SctpAssociation.hpp b/webrtc/SctpAssociation.hpp index 9c46d275..715c68a9 100644 --- a/webrtc/SctpAssociation.hpp +++ b/webrtc/SctpAssociation.hpp @@ -3,7 +3,7 @@ #ifdef ENABLE_SCTP #include -#include "Utils.hpp" +#include "Util/Byte.hpp" #include "Poller/EventPoller.h" namespace RTC @@ -62,8 +62,8 @@ namespace RTC return ( (len >= 12) && // Must have Source Port Number and Destination Port Number set to 5000 (hack). - (Utils::Byte::Get2Bytes(data, 0) == 5000) && - (Utils::Byte::Get2Bytes(data, 2) == 5000) + (toolkit::Byte::Get2Bytes(data, 0) == 5000) && + (toolkit::Byte::Get2Bytes(data, 2) == 5000) ); // clang-format on } diff --git a/webrtc/Sdp.cpp b/webrtc/Sdp.cpp index b23969c1..a6412e0b 100644 --- a/webrtc/Sdp.cpp +++ b/webrtc/Sdp.cpp @@ -1,1894 +1,2079 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#include "Sdp.h" -#include "Rtsp/Rtsp.h" -#include "Common/config.h" -#include - -using namespace std; -using namespace toolkit; - -namespace mediakit { - -namespace Rtc { -#define RTC_FIELD "rtc." -const string kPreferredCodecA = RTC_FIELD "preferredCodecA"; -const string kPreferredCodecV = RTC_FIELD "preferredCodecV"; -static onceToken token([]() { - mINI::Instance()[kPreferredCodecA] = "PCMA,PCMU,opus,mpeg4-generic"; - mINI::Instance()[kPreferredCodecV] = "H264,H265,AV1,VP9,VP8"; -}); -} // namespace Rtc - -using onCreateSdpItem = function; -static map sdpItemCreator; - -template -void registerSdpItem() { - onCreateSdpItem func = [](const string &key, const string &value) { - auto ret = std::make_shared(); - ret->parse(value); - return ret; - }; - Item item; - sdpItemCreator.emplace(item.getKey(), std::move(func)); -} - -class DirectionInterface { -public: - virtual RtpDirection getDirection() const = 0; -}; - -class SdpDirectionSendonly : public SdpItem, public DirectionInterface { -public: - const char *getKey() const override { return getRtpDirectionString(getDirection()); } - RtpDirection getDirection() const override { return RtpDirection::sendonly; } -}; - -class SdpDirectionRecvonly : public SdpItem, public DirectionInterface { -public: - const char *getKey() const override { return getRtpDirectionString(getDirection()); } - RtpDirection getDirection() const override { return RtpDirection::recvonly; } -}; - -class SdpDirectionSendrecv : public SdpItem, public DirectionInterface { -public: - const char *getKey() const override { return getRtpDirectionString(getDirection()); } - RtpDirection getDirection() const override { return RtpDirection::sendrecv; } -}; - -class SdpDirectionInactive : public SdpItem, public DirectionInterface { -public: - const char *getKey() const override { return getRtpDirectionString(getDirection()); } - RtpDirection getDirection() const override { return RtpDirection::inactive; } -}; - -class DirectionInterfaceImp : public SdpItem, public DirectionInterface { -public: - DirectionInterfaceImp(RtpDirection direct) { direction = direct; } - const char *getKey() const override { return getRtpDirectionString(getDirection()); } - RtpDirection getDirection() const override { return direction; } - -private: - RtpDirection direction; -}; - -static bool registerAllItem() { - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem>(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - registerSdpItem(); - return true; -} - -static map dtls_role_map = { - {"active", DtlsRole::active}, - {"passive", DtlsRole::passive}, - {"actpass", DtlsRole::actpass} -}; - -DtlsRole getDtlsRole(const string &str) { - auto it = dtls_role_map.find(str); - return it == dtls_role_map.end() ? DtlsRole::invalid : it->second; -} - -const char *getDtlsRoleString(DtlsRole role) { - switch (role) { - case DtlsRole::active: return "active"; - case DtlsRole::passive: return "passive"; - case DtlsRole::actpass: return "actpass"; - default: return "invalid"; - } -} - -static map direction_map = { - {"sendonly", RtpDirection::sendonly}, - {"recvonly", RtpDirection::recvonly}, - {"sendrecv", RtpDirection::sendrecv}, - {"inactive", RtpDirection::inactive} -}; - -RtpDirection getRtpDirection(const string &str) { - auto it = direction_map.find(str); - return it == direction_map.end() ? RtpDirection::invalid : it->second; -} - -const char *getRtpDirectionString(RtpDirection val) { - switch (val) { - case RtpDirection::sendonly: return "sendonly"; - case RtpDirection::recvonly: return "recvonly"; - case RtpDirection::sendrecv: return "sendrecv"; - case RtpDirection::inactive: return "inactive"; - default: return "invalid"; - } -} - -////////////////////////////////////////////////////////////////////////////////////////// - -string RtcSdpBase::toString() const { - _StrPrinter printer; - for (auto &item : items) { - printer << item->getKey() << "=" << item->toString() << "\r\n"; - } - return std::move(printer); -} - -RtpDirection RtcSdpBase::getDirection() const { - for (auto &item : items) { - auto attr = dynamic_pointer_cast(item); - if (attr) { - auto dir = dynamic_pointer_cast(attr->detail); - if (dir) { - return dir->getDirection(); - } - } - } - return RtpDirection::invalid; -} - -SdpItem::Ptr RtcSdpBase::getItem(char key_c, const char *attr_key) const { - std::string key(1, key_c); - for (auto item : items) { - if (strcasecmp(item->getKey(), key.data()) == 0) { - if (!attr_key) { - return item; - } - auto attr = dynamic_pointer_cast(item); - if (attr && !strcasecmp(attr->detail->getKey(), attr_key)) { - return attr->detail; - } - } - } - return SdpItem::Ptr(); -} - -////////////////////////////////////////////////////////////////////////// -int RtcSessionSdp::getVersion() const { - return atoi(getStringItem('v').data()); -} - -SdpOrigin RtcSessionSdp::getOrigin() const { - return getItemClass('o'); -} - -string RtcSessionSdp::getSessionName() const { - return getStringItem('s'); -} - -string RtcSessionSdp::getSessionInfo() const { - return getStringItem('i'); -} - -SdpTime RtcSessionSdp::getSessionTime() const { - return getItemClass('t'); -} - -SdpConnection RtcSessionSdp::getConnection() const { - return getItemClass('c'); -} - -SdpBandwidth RtcSessionSdp::getBandwidth() const { - return getItemClass('b'); -} - -string RtcSessionSdp::getUri() const { - return getStringItem('u'); -} - -string RtcSessionSdp::getEmail() const { - return getStringItem('e'); -} - -string RtcSessionSdp::getPhone() const { - return getStringItem('p'); -} - -string RtcSessionSdp::getTimeZone() const { - return getStringItem('z'); -} - -string RtcSessionSdp::getEncryptKey() const { - return getStringItem('k'); -} - -string RtcSessionSdp::getRepeatTimes() const { - return getStringItem('r'); -} - -////////////////////////////////////////////////////////////////////// - -void RtcSessionSdp::parse(const string &str) { - static auto flag = registerAllItem(); - RtcSdpBase *media = nullptr; - auto lines = split(str, "\n"); - std::set line_set; - for (auto &line : lines) { - trim(line); - if (line.size() < 3 || line[1] != '=') { - continue; - } - - if (!line_set.emplace(line).second) { - continue; - } - - auto key = line.substr(0, 1); - auto value = line.substr(2); - if (!strcasecmp(key.data(), "m")) { - medias.emplace_back(RtcSdpBase()); - media = &medias.back(); - line_set.clear(); - } - - SdpItem::Ptr item; - auto it = sdpItemCreator.find(key); - if (it != sdpItemCreator.end()) { - item = it->second(key, value); - } else { - item = std::make_shared(key); - item->parse(value); - } - if (media) { - media->addItem(std::move(item)); - } else { - addItem(std::move(item)); - } - } -} - -string RtcSessionSdp::toString() const { - _StrPrinter printer; - printer << RtcSdpBase::toString(); - for (auto &media : medias) { - printer << media.toString(); - } - - return std::move(printer); -} - -////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_SDP(exp) CHECK(exp, "解析sdp ", getKey(), " 字段失败:", str) - -void SdpTime::parse(const string &str) { - CHECK_SDP(sscanf(str.data(), "%" SCNu64 " %" SCNu64, &start, &stop) == 2); -} - -string SdpTime::toString() const { - if (value.empty()) { - value = to_string(start) + " " + to_string(stop); - } - return SdpItem::toString(); -} - -void SdpOrigin::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() == 6); - username = vec[0]; - session_id = vec[1]; - session_version = vec[2]; - nettype = vec[3]; - addrtype = vec[4]; - address = vec[5]; -} - -string SdpOrigin::toString() const { - if (value.empty()) { - value = username + " " + session_id + " " + session_version + " " + nettype + " " + addrtype + " " + address; - } - return SdpItem::toString(); -} - -void SdpConnection::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() == 3); - nettype = vec[0]; - addrtype = vec[1]; - address = vec[2]; -} - -string SdpConnection::toString() const { - if (value.empty()) { - value = nettype + " " + addrtype + " " + address; - } - return SdpItem::toString(); -} - -void SdpBandwidth::parse(const string &str) { - auto vec = split(str, ":"); - CHECK_SDP(vec.size() == 2); - bwtype = vec[0]; - bandwidth = atoi(vec[1].data()); -} - -string SdpBandwidth::toString() const { - if (value.empty()) { - value = bwtype + ":" + to_string(bandwidth); - } - return SdpItem::toString(); -} - -void SdpMedia::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() >= 4); - type = getTrackType(vec[0]); - CHECK_SDP(type != TrackInvalid); - port = atoi(vec[1].data()); - proto = vec[2]; - for (size_t i = 3; i < vec.size(); ++i) { - fmts.emplace_back(vec[i]); - } -} - -string SdpMedia::toString() const { - if (value.empty()) { - value = string(getTrackString(type)) + " " + to_string(port) + " " + proto; - for (auto fmt : fmts) { - value += ' '; - value += fmt; - } - } - return SdpItem::toString(); -} - -void SdpAttr::parse(const string &str) { - auto pos = str.find(':'); - auto key = pos == string::npos ? str : str.substr(0, pos); - auto value = pos == string::npos ? string() : str.substr(pos + 1); - auto it = sdpItemCreator.find(key); - if (it != sdpItemCreator.end()) { - detail = it->second(key, value); - } else { - detail = std::make_shared(key); - detail->parse(value); - } -} - -string SdpAttr::toString() const { - if (value.empty()) { - auto detail_value = detail->toString(); - if (detail_value.empty()) { - value = detail->getKey(); - } else { - value = string(detail->getKey()) + ":" + detail_value; - } - } - return SdpItem::toString(); -} - -void SdpAttrGroup::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() >= 2); - type = vec[0]; - vec.erase(vec.begin()); - mids = std::move(vec); -} - -string SdpAttrGroup::toString() const { - if (value.empty()) { - value = type; - for (auto mid : mids) { - value += ' '; - value += mid; - } - } - return SdpItem::toString(); -} - -void SdpAttrMsidSemantic::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() >= 1); - msid = vec[0]; - token = vec.size() > 1 ? vec[1] : ""; -} - -string SdpAttrMsidSemantic::toString() const { - if (value.empty()) { - if (token.empty()) { - value = string(" ") + msid; - } else { - value = string(" ") + msid + " " + token; - } - } - return SdpItem::toString(); -} - -void SdpAttrRtcp::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() == 4); - port = atoi(vec[0].data()); - nettype = vec[1]; - addrtype = vec[2]; - address = vec[3]; -} - -string SdpAttrRtcp::toString() const { - if (value.empty()) { - value = to_string(port) + " " + nettype + " " + addrtype + " " + address; - } - return SdpItem::toString(); -} - -void SdpAttrIceOption::parse(const string &str) { - auto vec = split(str, " "); - for (auto &v : vec) { - if (!strcasecmp(v.data(), "trickle")) { - trickle = true; - continue; - } - if (!strcasecmp(v.data(), "renomination")) { - renomination = true; - continue; - } - } -} - -string SdpAttrIceOption::toString() const { - if (value.empty()) { - if (trickle && renomination) { - value = "trickle renomination"; - } else if (trickle) { - value = "trickle"; - } else if (renomination) { - value = "renomination"; - } - } - return value; -} - -void SdpAttrFingerprint::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() == 2); - algorithm = vec[0]; - hash = vec[1]; -} - -string SdpAttrFingerprint::toString() const { - if (value.empty()) { - value = algorithm + " " + hash; - } - return SdpItem::toString(); -} - -void SdpAttrSetup::parse(const string &str) { - role = getDtlsRole(str); - CHECK_SDP(role != DtlsRole::invalid); -} - -string SdpAttrSetup::toString() const { - if (value.empty()) { - value = getDtlsRoleString(role); - } - return SdpItem::toString(); -} - -void SdpAttrExtmap::parse(const string &str) { - char buf[128] = { 0 }; - char direction_buf[32] = { 0 }; - if (sscanf(str.data(), "%" SCNd8 "/%31[^ ] %127s", &id, direction_buf, buf) != 3) { - CHECK_SDP(sscanf(str.data(), "%" SCNd8 " %127s", &id, buf) == 2); - direction = RtpDirection::sendrecv; - } else { - direction = getRtpDirection(direction_buf); - } - ext = buf; -} - -string SdpAttrExtmap::toString() const { - if (value.empty()) { - if (direction == RtpDirection::invalid || direction == RtpDirection::sendrecv) { - value = to_string((int)id) + " " + ext; - } else { - value = to_string((int)id) + "/" + getRtpDirectionString(direction) + " " + ext; - } - } - return SdpItem::toString(); -} - -void SdpAttrRtpMap::parse(const string &str) { - char buf[32] = { 0 }; - if (sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32 "/%" SCNd32, &pt, buf, &sample_rate, &channel) != 4) { - CHECK_SDP(sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32, &pt, buf, &sample_rate) == 3); - if (getTrackType(getCodecId(buf)) == TrackAudio) { - // 未指定通道数时,且为音频时,那么通道数默认为1 [AUTO-TRANSLATED:bd128fbd] - // If the number of channels is not specified and it is audio, the number of channels defaults to 1 - channel = 1; - } - } - codec = buf; -} - -string SdpAttrRtpMap::toString() const { - if (value.empty()) { - value = to_string((int)pt) + " " + codec + "/" + to_string(sample_rate); - if (channel) { - value += '/'; - value += to_string(channel); - } - } - return SdpItem::toString(); -} - -void SdpAttrRtcpFb::parse(const string &str_in) { - auto str = str_in + "\n"; - char rtcp_type_buf[32] = { 0 }; - CHECK_SDP(sscanf(str.data(), "%" SCNu8 " %31[^\n]", &pt, rtcp_type_buf) == 2); - rtcp_type = rtcp_type_buf; -} - -string SdpAttrRtcpFb::toString() const { - if (value.empty()) { - value = to_string((int)pt) + " " + rtcp_type; - } - return SdpItem::toString(); -} - -void SdpAttrFmtp::parse(const string &str) { - auto pos = str.find(' '); - CHECK_SDP(pos != string::npos); - pt = atoi(str.substr(0, pos).data()); - auto vec = split(str.substr(pos + 1), ";"); - for (auto &item : vec) { - trim(item); - auto pos = item.find('='); - if (pos == string::npos) { - fmtp.emplace(std::make_pair(item, "")); - } else { - fmtp.emplace(std::make_pair(item.substr(0, pos), item.substr(pos + 1))); - } - } - CHECK_SDP(!fmtp.empty()); -} - -string SdpAttrFmtp::toString() const { - if (value.empty()) { - value = to_string((int)pt); - int i = 0; - for (auto &pr : fmtp) { - value += (i++ ? ';' : ' '); - value += pr.first + "=" + pr.second; - } - } - return SdpItem::toString(); -} - -void SdpAttrSSRC::parse(const string &str_in) { - auto str = str_in + '\n'; - char attr_buf[32] = { 0 }; - char attr_val_buf[128] = { 0 }; - if (3 == sscanf(str.data(), "%" SCNu32 " %31[^:]:%127[^\n]", &ssrc, attr_buf, attr_val_buf)) { - attribute = attr_buf; - attribute_value = attr_val_buf; - } else if (2 == sscanf(str.data(), "%" SCNu32 " %31s[^\n]", &ssrc, attr_buf)) { - attribute = attr_buf; - } else { - CHECK_SDP(0); - } -} - -string SdpAttrSSRC::toString() const { - if (value.empty()) { - value = to_string(ssrc) + ' '; - value += attribute; - if (!attribute_value.empty()) { - value += ':'; - value += attribute_value; - } - } - return SdpItem::toString(); -} - -void SdpAttrSSRCGroup::parse(const string &str) { - auto vec = split(str, " "); - CHECK_SDP(vec.size() >= 3); - type = std::move(vec[0]); - CHECK(isFID() || isSIM()); - vec.erase(vec.begin()); - for (auto ssrc : vec) { - ssrcs.emplace_back((uint32_t)atoll(ssrc.data())); - } -} - -string SdpAttrSSRCGroup::toString() const { - if (value.empty()) { - value = type; - // 最少要求2个ssrc [AUTO-TRANSLATED:968acb83] - // At least 2 SSRCs are required - CHECK(ssrcs.size() >= 2); - for (auto &ssrc : ssrcs) { - value += ' '; - value += to_string(ssrc); - } - } - return SdpItem::toString(); -} - -void SdpAttrSctpMap::parse(const string &str) { - char subtypes_buf[64] = { 0 }; - CHECK_SDP(3 == sscanf(str.data(), "%" SCNu16 " %63[^ ] %" SCNd32, &port, subtypes_buf, &streams)); - subtypes = subtypes_buf; -} - -string SdpAttrSctpMap::toString() const { - if (value.empty()) { - value = to_string(port); - value += ' '; - value += subtypes; - value += ' '; - value += to_string(streams); - } - return SdpItem::toString(); -} - -void SdpAttrCandidate::parse(const string &str) { - char foundation_buf[40] = { 0 }; - char transport_buf[16] = { 0 }; - char address_buf[64] = { 0 }; - char type_buf[16] = { 0 }; - - // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 - CHECK_SDP(sscanf(str.data(), "%32[^ ] %" SCNu32 " %15[^ ] %" SCNu32 " %63[^ ] %" SCNu16 " typ %15[^ ]", - foundation_buf, &component, transport_buf, &priority, address_buf, &port, type_buf) == 7); - foundation = foundation_buf; - transport = transport_buf; - address = address_buf; - type = type_buf; - auto pos = str.find(type); - if (pos != string::npos) { - auto remain = str.substr(pos + type.size()); - trim(remain); - if (!remain.empty()) { - auto vec = split(remain, " "); - string key; - for (auto &item : vec) { - if (key.empty()) { - key = item; - } else { - arr.emplace_back(std::make_pair(std::move(key), std::move(item))); - } - } - } - } -} - -string SdpAttrCandidate::toString() const { - if (value.empty()) { - value = foundation + " " + to_string(component) + " " + transport + " " + to_string(priority) + " " + address + " " + to_string(port) + " typ " + type; - for (auto &pr : arr) { - value += ' '; - value += pr.first; - value += ' '; - value += pr.second; - } - } - return SdpItem::toString(); -} - -void SdpAttrSimulcast::parse(const string &str) { - // https://www.meetecho.com/blog/simulcast-janus-ssrc/ - // a=simulcast:send/recv q;h;f - // a=simulcast:send/recv [rid=]q;h;f - // a=simulcast: recv h;m;l - // - auto vec = split(str, " "); - CHECK_SDP(vec.size() == 2); - direction = vec[0]; - rids = split(vec[1], ";"); -} - -string SdpAttrSimulcast::toString() const { - if (value.empty()) { - value = direction + " "; - bool first = true; - for (auto &rid : rids) { - if (first) { - first = false; - } else { - value += ';'; - } - value += rid; - } - } - return SdpItem::toString(); -} - -void SdpAttrRid::parse(const string &str) { - auto vec = split(str, " "); - CHECK(vec.size() >= 2); - rid = vec[0]; - direction = vec[1]; -} - -string SdpAttrRid::toString() const { - if (value.empty()) { - value = rid + " " + direction; - } - return SdpItem::toString(); -} - -void RtcSession::loadFrom(const string &str) { - RtcSessionSdp sdp; - sdp.parse(str); - - version = sdp.getVersion(); - origin = sdp.getOrigin(); - session_name = sdp.getSessionName(); - session_info = sdp.getSessionInfo(); - connection = sdp.getConnection(); - time = sdp.getSessionTime(); - msid_semantic = sdp.getItemClass('a', "msid-semantic"); - for (auto &media : sdp.medias) { - auto mline = media.getItemClass('m'); - this->media.emplace_back(); - auto &rtc_media = this->media.back(); - rtc_media.mid = media.getStringItem('a', "mid"); - rtc_media.proto = mline.proto; - rtc_media.type = mline.type; - rtc_media.port = mline.port; - rtc_media.addr = media.getItemClass('c'); - rtc_media.bandwidth = media.getItemClass('b'); - rtc_media.ice_ufrag = media.getStringItem('a', "ice-ufrag"); - rtc_media.ice_pwd = media.getStringItem('a', "ice-pwd"); - rtc_media.role = media.getItemClass('a', "setup").role; - rtc_media.fingerprint = media.getItemClass('a', "fingerprint"); - if (rtc_media.fingerprint.empty()) { - rtc_media.fingerprint = sdp.getItemClass('a', "fingerprint"); - } - rtc_media.ice_lite = media.getItem('a', "ice-lite").operator bool(); - auto ice_options = media.getItemClass('a', "ice-options"); - rtc_media.ice_trickle = ice_options.trickle; - rtc_media.ice_renomination = ice_options.renomination; - rtc_media.candidate = media.getAllItem('a', "candidate"); - - if (mline.type == TrackType::TrackApplication) { - rtc_media.sctp_port = atoi(media.getStringItem('a', "sctp-port").data()); - rtc_media.sctpmap = media.getItemClass('a', "sctpmap"); - continue; - } - rtc_media.rtcp_addr = media.getItemClass('a', "rtcp"); - rtc_media.direction = media.getDirection(); - rtc_media.extmap = media.getAllItem('a', "extmap"); - rtc_media.rtcp_mux = media.getItem('a', "rtcp-mux").operator bool(); - rtc_media.rtcp_rsize = media.getItem('a', "rtcp-rsize").operator bool(); - - map rtc_ssrc_map; - auto ssrc_attr = media.getAllItem('a', "ssrc"); - for (auto &ssrc : ssrc_attr) { - auto &rtc_ssrc = rtc_ssrc_map[ssrc.ssrc]; - rtc_ssrc.ssrc = ssrc.ssrc; - if (!strcasecmp(ssrc.attribute.data(), "cname")) { - rtc_ssrc.cname = ssrc.attribute_value; - continue; - } - if (!strcasecmp(ssrc.attribute.data(), "msid")) { - rtc_ssrc.msid = ssrc.attribute_value; - continue; - } - if (!strcasecmp(ssrc.attribute.data(), "mslabel")) { - rtc_ssrc.mslabel = ssrc.attribute_value; - continue; - } - if (!strcasecmp(ssrc.attribute.data(), "label")) { - rtc_ssrc.label = ssrc.attribute_value; - continue; - } - } - - auto ssrc_groups = media.getAllItem('a', "ssrc-group"); - bool have_rtx_ssrc = false; - SdpAttrSSRCGroup *ssrc_group_sim = nullptr; - for (auto &group : ssrc_groups) { - if (group.isFID()) { - have_rtx_ssrc = true; - // ssrc-group:FID字段必须包含rtp/rtx的ssrc [AUTO-TRANSLATED:3da97d7d] - // The ssrc-group:FID field must contain the SSRCs of rtp/rtx - CHECK(group.ssrcs.size() == 2); - // 根据rtp ssrc找到对象 [AUTO-TRANSLATED:c0a56b42] - // Find the object based on the RTP SSRC - auto it = rtc_ssrc_map.find(group.ssrcs[0]); - CHECK(it != rtc_ssrc_map.end()); - // 设置rtx ssrc [AUTO-TRANSLATED:422e2a55] - // Set the RTX SSRC - it->second.rtx_ssrc = group.ssrcs[1]; - rtc_media.rtp_rtx_ssrc.emplace_back(it->second); - } else if (group.isSIM()) { - CHECK(!ssrc_group_sim); - ssrc_group_sim = &group; - } - } - - if (!have_rtx_ssrc) { - // 按照sdp顺序依次添加ssrc [AUTO-TRANSLATED:0996ba7e] - // Add SSRCs in the order of SDP - for (auto &attr : ssrc_attr) { - if (attr.attribute == "cname") { - rtc_media.rtp_rtx_ssrc.emplace_back(rtc_ssrc_map[attr.ssrc]); - } - } - } - - auto simulcast = media.getItemClass('a', "simulcast"); - if (!simulcast.empty()) { - // a=rid:h send - // a=rid:m send - // a=rid:l send - // a=simulcast:send h;m;l - // 风格的simulcast [AUTO-TRANSLATED:94ac2d55] - // Style of simulcast - unordered_set rid_map; - for (auto &rid : simulcast.rids) { - rid_map.emplace(rid); - } - for (auto &rid : media.getAllItem('a', "rid")) { - CHECK(rid.direction == simulcast.direction); - CHECK(rid_map.find(rid.rid) != rid_map.end()); - } - // simulcast最少要求2种方案 [AUTO-TRANSLATED:31732a7a] - // Simulcast requires at least 2 schemes - CHECK(simulcast.rids.size() >= 2); - rtc_media.rtp_rids = simulcast.rids; - } - - if (ssrc_group_sim) { - // 指定了a=ssrc-group:SIM [AUTO-TRANSLATED:5732661e] - // a=ssrc-group:SIM is specified - for (auto ssrc : ssrc_group_sim->ssrcs) { - auto it = rtc_ssrc_map.find(ssrc); - CHECK(it != rtc_ssrc_map.end()); - rtc_media.rtp_ssrc_sim.emplace_back(it->second); - } - } else if (!rtc_media.rtp_rids.empty()) { - // 未指定a=ssrc-group:SIM, 但是指定了a=simulcast, 那么只能根据ssrc顺序来对应rid顺序 [AUTO-TRANSLATED:b198a817] - // a=ssrc-group:SIM is not specified, but a=simulcast is specified, so the RID order can only be matched according to the SSRC order - rtc_media.rtp_ssrc_sim = rtc_media.rtp_rtx_ssrc; - } - - if (!rtc_media.supportSimulcast()) { - // 不支持simulcast的情况下,最多一组ssrc [AUTO-TRANSLATED:3ea8ed65] - // In the case of not supporting simulcast, there is at most one group of SSRCs - CHECK(rtc_media.rtp_rtx_ssrc.size() <= 1); - } else { - // simulcast的情况下,要么没有指定ssrc,要么指定的ssrc个数与rid个数一致 [AUTO-TRANSLATED:1d45ce03] - // In the case of simulcast, either no SSRC is specified or the number of specified SSRCs is consistent with the number of RIDs - // CHECK(rtc_media.rtp_ssrc_sim.empty() || rtc_media.rtp_ssrc_sim.size() == rtc_media.rtp_rids.size()); - } - - auto rtpmap_arr = media.getAllItem('a', "rtpmap"); - auto rtcpfb_arr = media.getAllItem('a', "rtcp-fb"); - auto fmtp_aar = media.getAllItem('a', "fmtp"); - // 方便根据pt查找rtpmap,一个pt必有一条 [AUTO-TRANSLATED:c3673faa] - // Convenient to find rtpmap based on pt, one pt must have one - map rtpmap_map; - // 方便根据pt查找rtcp-fb,一个pt可能有多条或0条 [AUTO-TRANSLATED:38361f68] - // Convenient to find rtcp-fb based on pt, one pt may have multiple or 0 - multimap rtcpfb_map; - // 方便根据pt查找fmtp,一个pt最多一条 [AUTO-TRANSLATED:be5d652d] - // Convenient to find fmtp based on pt, one pt has at most one - map fmtp_map; - - for (auto &rtpmap : rtpmap_arr) { - // 添加失败,有多条 [AUTO-TRANSLATED:717782c0] - // Add failed, there are multiple - CHECK(rtpmap_map.emplace(rtpmap.pt, rtpmap).second, "该pt存在多条a=rtpmap:", (int)rtpmap.pt); - } - for (auto &rtpfb : rtcpfb_arr) { - rtcpfb_map.emplace(rtpfb.pt, rtpfb); - } - for (auto &fmtp : fmtp_aar) { - // 添加失败,有多条 [AUTO-TRANSLATED:717782c0] - // Add failed, there are multiple - CHECK(fmtp_map.emplace(fmtp.pt, fmtp).second, "该pt存在多条a=fmtp:", (int)fmtp.pt); - } - for (auto &item : mline.fmts) { - auto pt = atoi(item.c_str()); - CHECK(pt < 0xFF, "invalid payload type: ", item); - // 遍历所有编码方案的pt [AUTO-TRANSLATED:40f2db36] - // Traverse the pt of all encoding schemes - rtc_media.plan.emplace_back(); - auto &plan = rtc_media.plan.back(); - auto rtpmap_it = rtpmap_map.find(pt); - if (rtpmap_it == rtpmap_map.end()) { - plan.pt = pt; - plan.codec = RtpPayload::getName(pt); - plan.sample_rate = RtpPayload::getClockRate(pt); - plan.channel = RtpPayload::getAudioChannel(pt); - } else { - plan.pt = rtpmap_it->second.pt; - plan.codec = rtpmap_it->second.codec; - plan.sample_rate = rtpmap_it->second.sample_rate; - plan.channel = rtpmap_it->second.channel; - } - - auto fmtp_it = fmtp_map.find(pt); - if (fmtp_it != fmtp_map.end()) { - plan.fmtp = fmtp_it->second.fmtp; - } - for (auto rtpfb_it = rtcpfb_map.find(pt); rtpfb_it != rtcpfb_map.end() && rtpfb_it->second.pt == pt; ++rtpfb_it) { - plan.rtcp_fb.emplace(rtpfb_it->second.rtcp_type); - } - } - } - - group = sdp.getItemClass('a', "group"); -} - -void RtcSdpBase::toRtsp() { - for (auto it = items.begin(); it != items.end();) { - switch ((*it)->getKey()[0]) { - case 'v': - case 'o': - case 's': - case 'i': - case 't': - case 'c': - case 'b': { - ++it; - break; - } - - case 'm': { - auto m = dynamic_pointer_cast(*it); - CHECK(m); - m->proto = "RTP/AVP"; - ++it; - break; - } - case 'a': { - auto attr = dynamic_pointer_cast(*it); - CHECK(attr); - if (!strcasecmp(attr->detail->getKey(), "rtpmap") || !strcasecmp(attr->detail->getKey(), "fmtp")) { - ++it; - break; - } - } - default: { - it = items.erase(it); - break; - } - } - } -} - -string RtcSession::toRtspSdp() const { - RtcSession copy = *this; - copy.media.clear(); - for (auto &m : media) { - switch (m.type) { - case TrackAudio: - case TrackVideo: { - if (m.direction != RtpDirection::inactive) { - copy.media.emplace_back(m); - copy.media.back().plan.resize(1); - } - break; - } - default: continue; - } - } - - CHECK(!copy.media.empty()); - auto sdp = copy.toRtcSessionSdp(); - sdp->toRtsp(); - int i = 0; - for (auto &m : sdp->medias) { - m.toRtsp(); - m.addAttr(std::make_shared("control", string("trackID=") + to_string(i++))); - } - return sdp->toString(); -} - -void addSdpAttrSSRC(const RtcSSRC &rtp_ssrc, RtcSdpBase &media, uint32_t ssrc_num) { - assert(ssrc_num); - SdpAttrSSRC ssrc; - ssrc.ssrc = ssrc_num; - - ssrc.attribute = "cname"; - ssrc.attribute_value = rtp_ssrc.cname; - media.addAttr(std::make_shared(ssrc)); - - if (!rtp_ssrc.msid.empty()) { - ssrc.attribute = "msid"; - ssrc.attribute_value = rtp_ssrc.msid; - media.addAttr(std::make_shared(ssrc)); - } - - if (!rtp_ssrc.mslabel.empty()) { - ssrc.attribute = "mslabel"; - ssrc.attribute_value = rtp_ssrc.mslabel; - media.addAttr(std::make_shared(ssrc)); - } - - if (!rtp_ssrc.label.empty()) { - ssrc.attribute = "label"; - ssrc.attribute_value = rtp_ssrc.label; - media.addAttr(std::make_shared(ssrc)); - } -} - -RtcSessionSdp::Ptr RtcSession::toRtcSessionSdp() const { - RtcSessionSdp::Ptr ret = std::make_shared(); - auto &sdp = *ret; - sdp.addItem(std::make_shared>(to_string(version))); - sdp.addItem(std::make_shared(origin)); - sdp.addItem(std::make_shared>(session_name)); - if (!session_info.empty()) { - sdp.addItem(std::make_shared>(session_info)); - } - sdp.addItem(std::make_shared(time)); - if (connection.empty()) { - sdp.addItem(std::make_shared(connection)); - } - sdp.addAttr(std::make_shared(group)); - sdp.addAttr(std::make_shared(msid_semantic)); - - bool ice_lite = false; - - for (auto &m : media) { - sdp.medias.emplace_back(); - auto &sdp_media = sdp.medias.back(); - auto mline = std::make_shared(); - mline->type = m.type; - mline->port = m.port; - mline->proto = m.proto; - for (auto &p : m.plan) { - mline->fmts.emplace_back(to_string((int)p.pt)); - } - if (m.type == TrackApplication) { - mline->fmts.emplace_back("webrtc-datachannel"); - } - sdp_media.addItem(std::move(mline)); - sdp_media.addItem(std::make_shared(m.addr)); - if (!m.bandwidth.empty() && m.type != TrackAudio) { - sdp_media.addItem(std::make_shared(m.bandwidth)); - } - if (!m.rtcp_addr.empty()) { - sdp_media.addAttr(std::make_shared(m.rtcp_addr)); - } - - sdp_media.addAttr(std::make_shared(m.ice_ufrag)); - sdp_media.addAttr(std::make_shared(m.ice_pwd)); - if (m.ice_trickle || m.ice_renomination) { - auto attr = std::make_shared(); - attr->trickle = m.ice_trickle; - attr->renomination = m.ice_renomination; - sdp_media.addAttr(attr); - } - sdp_media.addAttr(std::make_shared(m.fingerprint)); - sdp_media.addAttr(std::make_shared(m.role)); - sdp_media.addAttr(std::make_shared(m.mid)); - if (m.ice_lite) { - sdp_media.addAttr(std::make_shared("ice-lite")); - ice_lite = true; - } - for (auto &ext : m.extmap) { - sdp_media.addAttr(std::make_shared(ext)); - } - if (m.direction != RtpDirection::invalid) { - sdp_media.addAttr(std::make_shared(m.direction)); - } - if (m.rtcp_mux) { - sdp_media.addAttr(std::make_shared("rtcp-mux")); - } - if (m.rtcp_rsize) { - sdp_media.addAttr(std::make_shared("rtcp-rsize")); - } - - if (m.type != TrackApplication) { - for (auto &p : m.plan) { - auto rtp_map = std::make_shared(); - rtp_map->pt = p.pt; - rtp_map->codec = p.codec; - rtp_map->sample_rate = p.sample_rate; - rtp_map->channel = p.channel; - // 添加a=rtpmap [AUTO-TRANSLATED:8bef5d64] - // Add a=rtpmap - sdp_media.addAttr(std::move(rtp_map)); - - for (auto &fb : p.rtcp_fb) { - auto rtcp_fb = std::make_shared(); - rtcp_fb->pt = p.pt; - rtcp_fb->rtcp_type = fb; - // 添加a=rtcp-fb [AUTO-TRANSLATED:11754b43] - // Add a=rtcp-fb - sdp_media.addAttr(std::move(rtcp_fb)); - } - - if (!p.fmtp.empty()) { - auto fmtp = std::make_shared(); - fmtp->pt = p.pt; - fmtp->fmtp = p.fmtp; - // 添加a=fmtp [AUTO-TRANSLATED:594a4425] - // Add a=fmtp - sdp_media.addAttr(std::move(fmtp)); - } - } - - { - // 添加a=msid字段 [AUTO-TRANSLATED:cf2c1471] - // Add a=msid field - if (!m.rtp_rtx_ssrc.empty()) { - if (!m.rtp_rtx_ssrc[0].msid.empty()) { - auto msid = std::make_shared(); - msid->parse(m.rtp_rtx_ssrc[0].msid); - sdp_media.addAttr(std::move(msid)); - } - } - } - - { - for (auto &ssrc : m.rtp_rtx_ssrc) { - // 添加a=ssrc字段 [AUTO-TRANSLATED:75ca5225] - // Add a=ssrc field - CHECK(!ssrc.empty()); - addSdpAttrSSRC(ssrc, sdp_media, ssrc.ssrc); - if (ssrc.rtx_ssrc) { - addSdpAttrSSRC(ssrc, sdp_media, ssrc.rtx_ssrc); - - // 生成a=ssrc-group:FID字段 [AUTO-TRANSLATED:22b1f966] - // Generate a=ssrc-group:FID field - // 有rtx ssrc [AUTO-TRANSLATED:fece8076] - // There is rtx ssrc - auto group = std::make_shared(); - group->type = "FID"; - group->ssrcs.emplace_back(ssrc.ssrc); - group->ssrcs.emplace_back(ssrc.rtx_ssrc); - sdp_media.addAttr(std::move(group)); - } - } - } - - { - if (m.rtp_ssrc_sim.size() >= 2) { - // simulcast 要求 2~3路 [AUTO-TRANSLATED:3237ffca] - // Simulcast requires 2~3 channels - auto group = std::make_shared(); - for (auto &ssrc : m.rtp_ssrc_sim) { - group->ssrcs.emplace_back(ssrc.ssrc); - } - // 添加a=ssrc-group:SIM字段 [AUTO-TRANSLATED:46b04aae] - // Add a=ssrc-group:SIM field - group->type = "SIM"; - sdp_media.addAttr(std::move(group)); - } - - if (m.rtp_rids.size() >= 2) { - auto simulcast = std::make_shared(); - simulcast->direction = "recv"; - simulcast->rids = m.rtp_rids; - sdp_media.addAttr(std::move(simulcast)); - - for (auto &rid : m.rtp_rids) { - auto attr_rid = std::make_shared(); - attr_rid->rid = rid; - attr_rid->direction = "recv"; - sdp_media.addAttr(std::move(attr_rid)); - } - } - } - - } else { - if (!m.sctpmap.empty()) { - sdp_media.addAttr(std::make_shared(m.sctpmap)); - } - sdp_media.addAttr(std::make_shared("sctp-port", to_string(m.sctp_port))); - } - - for (auto &cand : m.candidate) { - if (cand.port) { - sdp_media.addAttr(std::make_shared(cand)); - } - } - } - if (ice_lite) { - sdp.addAttr(std::make_shared("ice-lite")); - } - return ret; -} - -string RtcSession::toString() const { - return toRtcSessionSdp()->toString(); -} - -string RtcCodecPlan::getFmtp(const char *key) const { - for (auto &item : fmtp) { - if (strcasecmp(item.first.data(), key) == 0) { - return item.second; - } - } - return ""; -} - -const RtcCodecPlan *RtcMedia::getPlan(uint8_t pt) const { - for (auto &item : plan) { - if (item.pt == pt) { - return &item; - } - } - return nullptr; -} - -const RtcCodecPlan *RtcMedia::getPlan(const char *codec) const { - for (auto &item : plan) { - if (strcasecmp(item.codec.data(), codec) == 0) { - return &item; - } - } - return nullptr; -} - -const RtcCodecPlan *RtcMedia::getRelatedRtxPlan(uint8_t pt) const { - for (auto &item : plan) { - if (strcasecmp(item.codec.data(), "rtx") == 0) { - auto apt = atoi(item.getFmtp("apt").data()); - if (pt == apt) { - return &item; - } - } - } - return nullptr; -} - -uint32_t RtcMedia::getRtpSSRC() const { - if (rtp_rtx_ssrc.size()) { - return rtp_rtx_ssrc[0].ssrc; - } - return 0; -} - -uint32_t RtcMedia::getRtxSSRC() const { - if (rtp_rtx_ssrc.size()) { - return rtp_rtx_ssrc[0].rtx_ssrc; - } - return 0; -} - -bool RtcMedia::supportSimulcast() const { - if (!rtp_rids.empty()) { - return true; - } - if (!rtp_ssrc_sim.empty()) { - return true; - } - return false; -} - -void RtcMedia::checkValid() const { - CHECK(type != TrackInvalid); - CHECK(!mid.empty()); - CHECK(!proto.empty()); - CHECK(direction != RtpDirection::invalid || type == TrackApplication); - CHECK(!plan.empty() || type == TrackApplication); - CHECK(type == TrackApplication || rtcp_mux, "只支持rtcp-mux模式"); - - bool send_rtp = (direction == RtpDirection::sendonly || direction == RtpDirection::sendrecv); - if (!supportSimulcast()) { - // 非simulcast时,检查有没有指定rtp ssrc [AUTO-TRANSLATED:e2d53f8a] - // When not simulcast, check if the RTP SSRC is specified - CHECK(!rtp_rtx_ssrc.empty() || !send_rtp); - } - -#if 0 - // todo 发现Firefox(88.0)在mac平台下,开启rtx后没有指定ssrc [AUTO-TRANSLATED:9112d91a] - // todo Found that Firefox (88.0) on the mac platform does not specify ssrc when rtx is enabled - auto rtx_plan = getPlan("rtx"); - if (rtx_plan) { - // 开启rtx后必须指定rtx_ssrc [AUTO-TRANSLATED:c527f68d] - // RTX must be specified after rtx_ssrc is enabled - CHECK(rtp_rtx_ssrc.size() >= 2 || !send_rtp); - } -#endif -} - -void RtcSession::checkValid() const { - CHECK(version == 0); - CHECK(!origin.empty()); - CHECK(!session_name.empty()); - CHECK(!msid_semantic.empty()); - CHECK(!media.empty()); - CHECK(!group.mids.empty() && group.mids.size() <= media.size(), "只支持group BUNDLE模式"); - - bool have_active_media = false; - for (auto &item : media) { - item.checkValid(); - - if (TrackApplication == item.type) { - have_active_media = true; - } - switch (item.direction) { - case RtpDirection::sendrecv: - case RtpDirection::sendonly: - case RtpDirection::recvonly: have_active_media = true; break; - default: break; - } - } - CHECK(have_active_media, "必须确保最少有一个活跃的track"); -} - -const RtcMedia *RtcSession::getMedia(TrackType type) const { - for (auto &m : media) { - if (m.type == type) { - return &m; - } - } - return nullptr; -} - -bool RtcSession::supportRtcpFb(const string &name, TrackType type) const { - auto media = getMedia(type); - if (!media) { - return false; - } - auto &ref = media->plan[0].rtcp_fb; - return ref.find(name) != ref.end(); -} - -bool RtcSession::supportSimulcast() const { - for (auto &m : media) { - if (m.supportSimulcast()) { - return true; - } - } - return false; -} - -bool RtcSession::isOnlyDatachannel() const { - return 1 == media.size() && TrackApplication == media[0].type; -} - -string const SdpConst::kTWCCRtcpFb = "transport-cc"; -string const SdpConst::kRembRtcpFb = "goog-remb"; - -void RtcConfigure::RtcTrackConfigure::enableTWCC(bool enable) { - if (!enable) { - rtcp_fb.erase(SdpConst::kTWCCRtcpFb); - extmap.erase(RtpExtType::transport_cc); - } else { - rtcp_fb.emplace(SdpConst::kTWCCRtcpFb); - extmap.emplace(RtpExtType::transport_cc, RtpDirection::sendrecv); - } -} - -void RtcConfigure::RtcTrackConfigure::enableREMB(bool enable) { - if (!enable) { - rtcp_fb.erase(SdpConst::kRembRtcpFb); - extmap.erase(RtpExtType::abs_send_time); - } else { - rtcp_fb.emplace(SdpConst::kRembRtcpFb); - extmap.emplace(RtpExtType::abs_send_time, RtpDirection::sendrecv); - } -} - -static vector toCodecArray(const string &str) { - vector ret; - auto vec = split(str, ","); - for (auto &s : vec) { - auto codec = getCodecId(trim(s)); - if (codec != CodecInvalid) { - ret.emplace_back(codec); - } - } - return ret; -} - -void RtcConfigure::RtcTrackConfigure::setDefaultSetting(TrackType type) { - rtcp_mux = true; - rtcp_rsize = false; - group_bundle = true; - support_rtx = true; - support_red = false; - support_ulpfec = false; - ice_lite = true; - ice_trickle = true; - ice_renomination = false; - switch (type) { - case TrackAudio: { - // 此处调整偏好的编码格式优先级 [AUTO-TRANSLATED:b8719e66] - // Adjust the priority of preferred encoding formats here - GET_CONFIG_FUNC(vector, s_preferred_codec, Rtc::kPreferredCodecA, toCodecArray); - CHECK(!s_preferred_codec.empty(), "rtc音频偏好codec不能为空"); - preferred_codec = s_preferred_codec; - - rtcp_fb = { SdpConst::kTWCCRtcpFb, SdpConst::kRembRtcpFb }; - extmap = { { RtpExtType::ssrc_audio_level, RtpDirection::sendrecv }, - { RtpExtType::csrc_audio_level, RtpDirection::sendrecv }, - { RtpExtType::abs_send_time, RtpDirection::sendrecv }, - { RtpExtType::transport_cc, RtpDirection::sendrecv }, - // rtx重传rtp时,忽略sdes_mid类型的rtp ext,实测发现Firefox在接收rtx时,如果存在sdes_mid的ext,将导致无法播放 [AUTO-TRANSLATED:221df025] - // When rtx retransmits rtp, ignore the rtp ext of sdes_mid type. It is found that Firefox cannot play when receiving rtx if there is an ext of sdes_mid - //{RtpExtType::sdes_mid,RtpDirection::sendrecv}, - { RtpExtType::sdes_rtp_stream_id, RtpDirection::sendrecv }, - { RtpExtType::sdes_repaired_rtp_stream_id, RtpDirection::sendrecv } }; - break; - } - case TrackVideo: { - // 此处调整偏好的编码格式优先级 [AUTO-TRANSLATED:b8719e66] - // Adjust the priority of preferred encoding formats here - GET_CONFIG_FUNC(vector, s_preferred_codec, Rtc::kPreferredCodecV, toCodecArray); - CHECK(!s_preferred_codec.empty(), "rtc视频偏好codec不能为空"); - preferred_codec = s_preferred_codec; - - rtcp_fb = { SdpConst::kTWCCRtcpFb, SdpConst::kRembRtcpFb, "nack", "ccm fir", "nack pli" }; - extmap = { { RtpExtType::abs_send_time, RtpDirection::sendrecv }, - { RtpExtType::transport_cc, RtpDirection::sendrecv }, - // rtx重传rtp时,忽略sdes_mid类型的rtp ext,实测发现Firefox在接收rtx时,如果存在sdes_mid的ext,将导致无法播放 [AUTO-TRANSLATED:221df025] - // When rtx retransmits rtp, ignore the rtp ext of sdes_mid type. It is found that Firefox cannot play when receiving rtx if there is an ext of sdes_mid - //{RtpExtType::sdes_mid,RtpDirection::sendrecv}, - { RtpExtType::sdes_rtp_stream_id, RtpDirection::sendrecv }, - { RtpExtType::sdes_repaired_rtp_stream_id, RtpDirection::sendrecv }, - { RtpExtType::video_timing, RtpDirection::sendrecv }, - { RtpExtType::color_space, RtpDirection::sendrecv }, - { RtpExtType::video_content_type, RtpDirection::sendrecv }, - { RtpExtType::playout_delay, RtpDirection::sendrecv }, - // 手机端推webrtc 会带有旋转角度,rtc协议能正常播放 其他协议拉流画面旋转 [AUTO-TRANSLATED:3f2f9e0e] - // Mobile push webrtc will have a rotation angle, rtc protocol can play normally, other protocols pull stream picture rotation - //{RtpExtType::video_orientation, RtpDirection::sendrecv}, - { RtpExtType::toffset, RtpDirection::sendrecv }, - { RtpExtType::framemarking, RtpDirection::sendrecv } }; - break; - } - case TrackApplication: { - break; - } - default: break; - } -} - -void RtcConfigure::setDefaultSetting(string ice_ufrag, string ice_pwd, RtpDirection direction, const SdpAttrFingerprint &fingerprint) { - video.setDefaultSetting(TrackVideo); - audio.setDefaultSetting(TrackAudio); - application.setDefaultSetting(TrackApplication); - - video.ice_ufrag = audio.ice_ufrag = application.ice_ufrag = std::move(ice_ufrag); - video.ice_pwd = audio.ice_pwd = application.ice_pwd = std::move(ice_pwd); - video.direction = audio.direction = application.direction = direction; - video.fingerprint = audio.fingerprint = application.fingerprint = fingerprint; -} - -void RtcConfigure::addCandidate(const SdpAttrCandidate &candidate, TrackType type) { - switch (type) { - case TrackAudio: { - audio.candidate.emplace_back(candidate); - break; - } - case TrackVideo: { - video.candidate.emplace_back(candidate); - break; - } - case TrackApplication: { - application.candidate.emplace_back(candidate); - break; - } - default: { - if (audio.group_bundle) { - audio.candidate.emplace_back(candidate); - } - if (video.group_bundle) { - video.candidate.emplace_back(candidate); - } - if (application.group_bundle) { - application.candidate.emplace_back(candidate); - } - break; - } - } -} - -void RtcConfigure::enableTWCC(bool enable, TrackType type) { - switch (type) { - case TrackAudio: { - audio.enableTWCC(enable); - break; - } - case TrackVideo: { - video.enableTWCC(enable); - break; - } - default: { - audio.enableTWCC(enable); - video.enableTWCC(enable); - break; - } - } -} - -void RtcConfigure::enableREMB(bool enable, TrackType type) { - switch (type) { - case TrackAudio: { - audio.enableREMB(enable); - break; - } - case TrackVideo: { - video.enableREMB(enable); - break; - } - default: { - audio.enableREMB(enable); - video.enableREMB(enable); - break; - } - } -} - -shared_ptr RtcConfigure::createAnswer(const RtcSession &offer) const { - shared_ptr ret = std::make_shared(); - ret->version = offer.version; - ret->origin = offer.origin; - ret->session_name = offer.session_name; - ret->msid_semantic = offer.msid_semantic; - - for (auto &m : offer.media) { - matchMedia(ret, m); - } - - // 设置音视频端口复用 [AUTO-TRANSLATED:ffe27d17] - // Set audio and video port multiplexing - if (!offer.group.mids.empty()) { - for (auto &m : ret->media) { - // The remote end has rejected (port 0) the m-section, so it should not be putting its mid in the group attribute. - if (m.port) { - ret->group.mids.emplace_back(m.mid); - } - } - } - return ret; -} - -static RtpDirection matchDirection(RtpDirection offer_direction, RtpDirection supported) { - switch (offer_direction) { - case RtpDirection::sendonly: { - if (supported != RtpDirection::recvonly && supported != RtpDirection::sendrecv) { - // 我们不支持接收 [AUTO-TRANSLATED:e4ef4034] - // We do not support receiving - return RtpDirection::inactive; - } - return RtpDirection::recvonly; - } - - case RtpDirection::recvonly: { - if (supported != RtpDirection::sendonly && supported != RtpDirection::sendrecv) { - // 我们不支持发送 [AUTO-TRANSLATED:6505a226] - // We do not support sending - return RtpDirection::inactive; - } - return RtpDirection::sendonly; - } - - // 对方支持发送接收,那么最终能力根据配置来决定 [AUTO-TRANSLATED:d234d603] - // The other party supports sending and receiving, so the final capability is determined by the configuration - case RtpDirection::sendrecv: return (supported == RtpDirection::invalid ? RtpDirection::inactive : supported); - case RtpDirection::inactive: return RtpDirection::inactive; - default: return RtpDirection::invalid; - } -} - -static DtlsRole mathDtlsRole(DtlsRole role) { - switch (role) { - case DtlsRole::actpass: - case DtlsRole::active: return DtlsRole::passive; - case DtlsRole::passive: return DtlsRole::active; - default: CHECK(0, "invalid role:", getDtlsRoleString(role)); return DtlsRole::passive; - } -} - -void RtcConfigure::matchMedia(const std::shared_ptr &ret, const RtcMedia &offer_media) const { - bool check_profile = true; - bool check_codec = true; - const RtcTrackConfigure *cfg_ptr = nullptr; - - switch (offer_media.type) { - case TrackAudio: cfg_ptr = &audio; break; - case TrackVideo: cfg_ptr = &video; break; - case TrackApplication: cfg_ptr = &application; break; - default: return; - } - auto &configure = *cfg_ptr; - -RETRY: - - if (offer_media.type == TrackApplication) { - RtcMedia answer_media = offer_media; - answer_media.role = mathDtlsRole(offer_media.role); - answer_media.ice_ufrag = configure.ice_ufrag; - answer_media.ice_pwd = configure.ice_pwd; - answer_media.fingerprint = configure.fingerprint; - answer_media.ice_lite = configure.ice_lite; -#ifdef ENABLE_SCTP - answer_media.candidate = configure.candidate; -#else - answer_media.port = 0; - WarnL << "answer sdp忽略application mline, 请安装usrsctp后再测试datachannel功能"; -#endif - ret->media.emplace_back(answer_media); - return; - } - for (auto &codec : configure.preferred_codec) { - if (offer_media.ice_lite && configure.ice_lite) { - WarnL << "answer sdp配置为ice_lite模式,与offer sdp中的ice_lite模式冲突"; - continue; - } - const RtcCodecPlan *selected_plan = nullptr; - for (auto &plan : offer_media.plan) { - // 先检查编码格式是否为偏好 [AUTO-TRANSLATED:b7fb32a0] - // First check if the encoding format is preferred - if (check_codec && getCodecId(plan.codec) != codec) { - continue; - } - // 命中偏好的编码格式,然后检查规格 [AUTO-TRANSLATED:a859c839] - // Hit the preferred encoding format, then check the specifications - if (check_profile && !onCheckCodecProfile(plan, codec)) { - continue; - } - // 找到中意的codec [AUTO-TRANSLATED:4b5eebfd] - // Find the desired codec - selected_plan = &plan; - break; - } - if (!selected_plan) { - // offer中该媒体的所有的codec都不支持 [AUTO-TRANSLATED:3b57b86f] - // All codecs for this media in the offer are not supported - continue; - } - RtcMedia answer_media; - answer_media.type = offer_media.type; - answer_media.mid = offer_media.mid; - answer_media.proto = offer_media.proto; - answer_media.port = offer_media.port; - answer_media.addr = offer_media.addr; - answer_media.bandwidth = offer_media.bandwidth; - answer_media.rtcp_addr = offer_media.rtcp_addr; - answer_media.rtcp_mux = offer_media.rtcp_mux && configure.rtcp_mux; - answer_media.rtcp_rsize = offer_media.rtcp_rsize && configure.rtcp_rsize; - answer_media.ice_trickle = offer_media.ice_trickle && configure.ice_trickle; - answer_media.ice_renomination = offer_media.ice_renomination && configure.ice_renomination; - answer_media.ice_ufrag = configure.ice_ufrag; - answer_media.ice_pwd = configure.ice_pwd; - answer_media.fingerprint = configure.fingerprint; - answer_media.ice_lite = configure.ice_lite; - answer_media.candidate = configure.candidate; - // copy simulicast setting - answer_media.rtp_rids = offer_media.rtp_rids; - answer_media.rtp_ssrc_sim = offer_media.rtp_ssrc_sim; - - answer_media.role = mathDtlsRole(offer_media.role); - - // 如果codec匹配失败,那么禁用该track [AUTO-TRANSLATED:037de9a8] - // If the codec matching fails, then disable the track - answer_media.direction = check_codec ? matchDirection(offer_media.direction, configure.direction) : RtpDirection::inactive; - if (answer_media.direction == RtpDirection::invalid) { - continue; - } - if (answer_media.direction == RtpDirection::sendrecv) { - // 如果是收发双向,那么我们拷贝offer sdp的ssrc,确保ssrc一致 [AUTO-TRANSLATED:d4a621f2] - // If it is bidirectional, then we copy the offer sdp ssrc to ensure ssrc consistency - answer_media.rtp_rtx_ssrc = offer_media.rtp_rtx_ssrc; - } - - // 添加媒体plan [AUTO-TRANSLATED:3f730050] - // Add media plan - answer_media.plan.emplace_back(*selected_plan); - onSelectPlan(answer_media.plan.back(), codec); - - set pt_selected = { selected_plan->pt }; - - // 添加rtx,red,ulpfec plan [AUTO-TRANSLATED:1abff0c1] - // Add rtx, red, ulpfec plan - if (configure.support_red || configure.support_rtx || configure.support_ulpfec) { - for (auto &plan : offer_media.plan) { - if (!strcasecmp(plan.codec.data(), "rtx")) { - if (configure.support_rtx && atoi(plan.getFmtp("apt").data()) == selected_plan->pt) { - answer_media.plan.emplace_back(plan); - pt_selected.emplace(plan.pt); - } - continue; - } - if (!strcasecmp(plan.codec.data(), "red")) { - if (configure.support_red) { - answer_media.plan.emplace_back(plan); - pt_selected.emplace(plan.pt); - } - continue; - } - if (!strcasecmp(plan.codec.data(), "ulpfec")) { - if (configure.support_ulpfec) { - answer_media.plan.emplace_back(plan); - pt_selected.emplace(plan.pt); - } - continue; - } - } - } - - // 对方和我方都支持的扩展,那么我们才支持 [AUTO-TRANSLATED:a6cd98b2] - // We only support extensions that are supported by both the other party and us - for (auto &ext : offer_media.extmap) { - auto it = configure.extmap.find(RtpExt::getExtType(ext.ext)); - if (it != configure.extmap.end()) { - auto new_dir = matchDirection(ext.direction, it->second); - switch (new_dir) { - case RtpDirection::invalid: - case RtpDirection::inactive: continue; - default: break; - } - answer_media.extmap.emplace_back(ext); - answer_media.extmap.back().direction = new_dir; - } - } - - auto &rtcp_fb_ref = answer_media.plan[0].rtcp_fb; - rtcp_fb_ref.clear(); - // 对方和我方都支持的rtcpfb,那么我们才支持 [AUTO-TRANSLATED:f10450bb] - // We only support rtcpfb that is supported by both the other party and us - for (auto &fp : selected_plan->rtcp_fb) { - if (configure.rtcp_fb.find(fp) != configure.rtcp_fb.end()) { - // 对方该rtcp被我们支持 [AUTO-TRANSLATED:3b16e666] - // The other party's rtcp is supported by us - rtcp_fb_ref.emplace(fp); - } - } - -#if 0 - // todo 此处为添加无效的plan,webrtc sdp通过调节plan pt顺序选择匹配的codec,意味着后面的codec其实放在sdp中是无意义的 [AUTO-TRANSLATED:502d0cb2] - // todo This is to add an invalid plan. WebRTC sdp selects the matching codec by adjusting the plan pt order, which means that the subsequent codecs are actually meaningless in the sdp - for (auto &plan : offer_media.plan) { - if (pt_selected.find(plan.pt) == pt_selected.end()) { - answer_media.plan.emplace_back(plan); - } - } -#endif - ret->media.emplace_back(answer_media); - return; - } - - if (check_profile) { - // 如果是由于检查profile导致匹配失败,那么重试一次,且不检查profile [AUTO-TRANSLATED:897fa4ae] - // If the matching fails due to profile check, retry once and do not check profile - check_profile = false; - goto RETRY; - } - - if (check_codec) { - // 如果是由于检查codec导致匹配失败,那么重试一次,且不检查codec [AUTO-TRANSLATED:fbd85968] - // If the matching fails due to codec check, retry once and do not check codec - check_codec = false; - goto RETRY; - } -} - -void RtcConfigure::setPlayRtspInfo(const string &sdp) { - RtcSession session; - video.direction = RtpDirection::inactive; - audio.direction = RtpDirection::inactive; - - session.loadFrom(sdp); - for (auto &m : session.media) { - switch (m.type) { - case TrackVideo: { - video.direction = RtpDirection::sendonly; - _rtsp_video_plan = std::make_shared(m.plan[0]); - video.preferred_codec.clear(); - video.preferred_codec.emplace_back(getCodecId(_rtsp_video_plan->codec)); - break; - } - case TrackAudio: { - audio.direction = RtpDirection::sendonly; - _rtsp_audio_plan = std::make_shared(m.plan[0]); - audio.preferred_codec.clear(); - audio.preferred_codec.emplace_back(getCodecId(_rtsp_audio_plan->codec)); - break; - } - default: break; - } - } -} - -static const string kProfile { "profile-level-id" }; -static const string kMode { "packetization-mode" }; - -bool RtcConfigure::onCheckCodecProfile(const RtcCodecPlan &plan, CodecId codec) const { - if (_rtsp_audio_plan && codec == getCodecId(_rtsp_audio_plan->codec)) { - if (plan.sample_rate != _rtsp_audio_plan->sample_rate || plan.channel != _rtsp_audio_plan->channel) { - // 音频采样率和通道数必须相同 [AUTO-TRANSLATED:6e591932] - // Audio sampling rate and number of channels must be the same - return false; - } - return true; - } - if (_rtsp_video_plan && codec == CodecH264 && getCodecId(_rtsp_video_plan->codec) == CodecH264) { - // h264时,profile-level-id [AUTO-TRANSLATED:94a5f360] - // When h264, profile-level-id - if (strcasecmp(_rtsp_video_plan->fmtp[kProfile].data(), const_cast(plan).fmtp[kProfile].data())) { - // profile-level-id 不匹配 [AUTO-TRANSLATED:814ec4c4] - // profile-level-id does not match - return false; - } - return true; - } - - return true; -} - -/** - Single NAI Unit Mode = 0. // Single NAI mode (Only nals from 1-23 are allowed) - Non Interleaved Mode = 1,// Non-interleaved Mode: 1-23,24 (STAP-A),28 (FU-A) are allowed - Interleaved Mode = 2, // 25 (STAP-B),26 (MTAP16),27 (MTAP24),28 (EU-A),and 29 (EU-B) are allowed. - Single NAI Unit Mode = 0. // Single NAI mode (Only nals from 1-23 are allowed) - Non Interleaved Mode = 1,// Non-interleaved Mode: 1-23,24 (STAP-A),28 (FU-A) are allowed - Interleaved Mode = 2, // 25 (STAP-B),26 (MTAP16),27 (MTAP24),28 (EU-A),and 29 (EU-B) are allowed. - * - * [AUTO-TRANSLATED:b1526114] - **/ -void RtcConfigure::onSelectPlan(RtcCodecPlan &plan, CodecId codec) const { - if (_rtsp_video_plan && codec == CodecH264 && getCodecId(_rtsp_video_plan->codec) == CodecH264) { - // h264时,设置packetization-mod为一致 [AUTO-TRANSLATED:59a00889] - // When h264, set packetization-mod to be consistent - auto mode = _rtsp_video_plan->fmtp[kMode]; - GET_CONFIG(bool, h264_stap_a, Rtp::kH264StapA); - plan.fmtp[kMode] = mode.empty() ? std::to_string(h264_stap_a) : mode; - } -} - -} // namespace mediakit \ No newline at end of file +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "Sdp.h" +#include "Rtsp/Rtsp.h" +#include "Common/config.h" +#include + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +namespace Rtc { +#define RTC_FIELD "rtc." +const string kPreferredCodecA = RTC_FIELD "preferredCodecA"; +const string kPreferredCodecV = RTC_FIELD "preferredCodecV"; +static onceToken token([]() { + mINI::Instance()[kPreferredCodecA] = "PCMA,PCMU,opus,mpeg4-generic"; + mINI::Instance()[kPreferredCodecV] = "H264,H265,AV1,VP9,VP8"; +}); +} // namespace Rtc + +using onCreateSdpItem = function; +static map sdpItemCreator; + +template +void registerSdpItem() { + onCreateSdpItem func = [](const string &key, const string &value) { + auto ret = std::make_shared(); + ret->parse(value); + return ret; + }; + Item item; + sdpItemCreator.emplace(item.getKey(), std::move(func)); +} + +class DirectionInterface { +public: + virtual RtpDirection getDirection() const = 0; +}; + +class SdpDirectionSendonly : public SdpItem, public DirectionInterface { +public: + const char *getKey() const override { return getRtpDirectionString(getDirection()); } + RtpDirection getDirection() const override { return RtpDirection::sendonly; } +}; + +class SdpDirectionRecvonly : public SdpItem, public DirectionInterface { +public: + const char *getKey() const override { return getRtpDirectionString(getDirection()); } + RtpDirection getDirection() const override { return RtpDirection::recvonly; } +}; + +class SdpDirectionSendrecv : public SdpItem, public DirectionInterface { +public: + const char *getKey() const override { return getRtpDirectionString(getDirection()); } + RtpDirection getDirection() const override { return RtpDirection::sendrecv; } +}; + +class SdpDirectionInactive : public SdpItem, public DirectionInterface { +public: + const char *getKey() const override { return getRtpDirectionString(getDirection()); } + RtpDirection getDirection() const override { return RtpDirection::inactive; } +}; + +class DirectionInterfaceImp : public SdpItem, public DirectionInterface { +public: + DirectionInterfaceImp(RtpDirection direct) { direction = direct; } + const char *getKey() const override { return getRtpDirectionString(getDirection()); } + RtpDirection getDirection() const override { return direction; } + +private: + RtpDirection direction; +}; + +static bool registerAllItem() { + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem>(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + registerSdpItem(); + return true; +} + +static map dtls_role_map = { + {"active", DtlsRole::active}, + {"passive", DtlsRole::passive}, + {"actpass", DtlsRole::actpass} +}; + +DtlsRole getDtlsRole(const string &str) { + auto it = dtls_role_map.find(str); + return it == dtls_role_map.end() ? DtlsRole::invalid : it->second; +} + +const char *getDtlsRoleString(DtlsRole role) { + switch (role) { + case DtlsRole::active: return "active"; + case DtlsRole::passive: return "passive"; + case DtlsRole::actpass: return "actpass"; + default: return "invalid"; + } +} + +static map direction_map = { + {"sendonly", RtpDirection::sendonly}, + {"recvonly", RtpDirection::recvonly}, + {"sendrecv", RtpDirection::sendrecv}, + {"inactive", RtpDirection::inactive} +}; + +RtpDirection getRtpDirection(const string &str) { + auto it = direction_map.find(str); + return it == direction_map.end() ? RtpDirection::invalid : it->second; +} + +const char *getRtpDirectionString(RtpDirection val) { + switch (val) { + case RtpDirection::sendonly: return "sendonly"; + case RtpDirection::recvonly: return "recvonly"; + case RtpDirection::sendrecv: return "sendrecv"; + case RtpDirection::inactive: return "inactive"; + default: return "invalid"; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// + +string RtcSdpBase::toString() const { + _StrPrinter printer; + for (auto &item : items) { + printer << item->getKey() << "=" << item->toString() << "\r\n"; + } + return printer; +} + +RtpDirection RtcSdpBase::getDirection() const { + for (auto &item : items) { + auto attr = dynamic_pointer_cast(item); + if (attr) { + auto dir = dynamic_pointer_cast(attr->detail); + if (dir) { + return dir->getDirection(); + } + } + } + return RtpDirection::invalid; +} + +SdpItem::Ptr RtcSdpBase::getItem(char key_c, const char *attr_key) const { + std::string key(1, key_c); + for (auto item : items) { + if (strcasecmp(item->getKey(), key.data()) == 0) { + if (!attr_key) { + return item; + } + auto attr = dynamic_pointer_cast(item); + if (attr && !strcasecmp(attr->detail->getKey(), attr_key)) { + return attr->detail; + } + } + } + return SdpItem::Ptr(); +} + +////////////////////////////////////////////////////////////////////////// +int RtcSessionSdp::getVersion() const { + return atoi(getStringItem('v').data()); +} + +SdpOrigin RtcSessionSdp::getOrigin() const { + return getItemClass('o'); +} + +string RtcSessionSdp::getSessionName() const { + return getStringItem('s'); +} + +string RtcSessionSdp::getSessionInfo() const { + return getStringItem('i'); +} + +SdpTime RtcSessionSdp::getSessionTime() const { + return getItemClass('t'); +} + +SdpConnection RtcSessionSdp::getConnection() const { + return getItemClass('c'); +} + +SdpBandwidth RtcSessionSdp::getBandwidth() const { + return getItemClass('b'); +} + +string RtcSessionSdp::getUri() const { + return getStringItem('u'); +} + +string RtcSessionSdp::getEmail() const { + return getStringItem('e'); +} + +string RtcSessionSdp::getPhone() const { + return getStringItem('p'); +} + +string RtcSessionSdp::getTimeZone() const { + return getStringItem('z'); +} + +string RtcSessionSdp::getEncryptKey() const { + return getStringItem('k'); +} + +string RtcSessionSdp::getRepeatTimes() const { + return getStringItem('r'); +} + +////////////////////////////////////////////////////////////////////// + +void RtcSessionSdp::parse(const string &str) { + static auto flag = registerAllItem(); + RtcSdpBase *media = nullptr; + auto lines = split(str, "\n"); + std::set line_set; + for (auto &line : lines) { + trim(line); + if (line.size() < 3 || line[1] != '=') { + continue; + } + + if (!line_set.emplace(line).second) { + continue; + } + + auto key = line.substr(0, 1); + auto value = line.substr(2); + if (!strcasecmp(key.data(), "m")) { + medias.emplace_back(RtcSdpBase()); + media = &medias.back(); + line_set.clear(); + } + + SdpItem::Ptr item; + auto it = sdpItemCreator.find(key); + if (it != sdpItemCreator.end()) { + item = it->second(key, value); + } else { + item = std::make_shared(key); + item->parse(value); + } + if (media) { + media->addItem(std::move(item)); + } else { + addItem(std::move(item)); + } + } +} + +string RtcSessionSdp::toString() const { + _StrPrinter printer; + printer << RtcSdpBase::toString(); + for (auto &media : medias) { + printer << media.toString(); + } + + return printer; +} + +////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_SDP(exp) CHECK(exp, "解析sdp ", getKey(), " 字段失败:", str) + +void SdpTime::parse(const string &str) { + CHECK_SDP(sscanf(str.data(), "%" SCNu64 " %" SCNu64, &start, &stop) == 2); +} + +string SdpTime::toString() const { + if (value.empty()) { + value = to_string(start) + " " + to_string(stop); + } + return SdpItem::toString(); +} + +void SdpOrigin::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() == 6); + username = vec[0]; + session_id = vec[1]; + session_version = vec[2]; + nettype = vec[3]; + addrtype = vec[4]; + address = vec[5]; +} + +string SdpOrigin::toString() const { + if (value.empty()) { + value = username + " " + session_id + " " + session_version + " " + nettype + " " + addrtype + " " + address; + } + return SdpItem::toString(); +} + +void SdpConnection::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() == 3); + nettype = vec[0]; + addrtype = vec[1]; + address = vec[2]; +} + +string SdpConnection::toString() const { + if (value.empty()) { + value = nettype + " " + addrtype + " " + address; + } + return SdpItem::toString(); +} + +void SdpBandwidth::parse(const string &str) { + auto vec = split(str, ":"); + CHECK_SDP(vec.size() == 2); + bwtype = vec[0]; + bandwidth = atoi(vec[1].data()); +} + +string SdpBandwidth::toString() const { + if (value.empty()) { + value = bwtype + ":" + to_string(bandwidth); + } + return SdpItem::toString(); +} + +void SdpMedia::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() >= 4); + type = getTrackType(vec[0]); + CHECK_SDP(type != TrackInvalid); + port = atoi(vec[1].data()); + proto = vec[2]; + for (size_t i = 3; i < vec.size(); ++i) { + fmts.emplace_back(vec[i]); + } +} + +string SdpMedia::toString() const { + if (value.empty()) { + value = string(getTrackString(type)) + " " + to_string(port) + " " + proto; + for (auto fmt : fmts) { + value += ' '; + value += fmt; + } + } + return SdpItem::toString(); +} + +void SdpAttr::parse(const string &str) { + auto pos = str.find(':'); + auto key = pos == string::npos ? str : str.substr(0, pos); + auto value = pos == string::npos ? string() : str.substr(pos + 1); + auto it = sdpItemCreator.find(key); + if (it != sdpItemCreator.end()) { + detail = it->second(key, value); + } else { + detail = std::make_shared(key); + detail->parse(value); + } +} + +string SdpAttr::toString() const { + if (value.empty()) { + auto detail_value = detail->toString(); + if (detail_value.empty()) { + value = detail->getKey(); + } else { + value = string(detail->getKey()) + ":" + detail_value; + } + } + return SdpItem::toString(); +} + +void SdpAttrGroup::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() >= 2); + type = vec[0]; + vec.erase(vec.begin()); + mids = std::move(vec); +} + +string SdpAttrGroup::toString() const { + if (value.empty()) { + value = type; + for (auto mid : mids) { + value += ' '; + value += mid; + } + } + return SdpItem::toString(); +} + +void SdpAttrMsidSemantic::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() >= 1); + msid = vec[0]; + token = vec.size() > 1 ? vec[1] : ""; +} + +string SdpAttrMsidSemantic::toString() const { + if (value.empty()) { + if (token.empty()) { + value = string(" ") + msid; + } else { + value = string(" ") + msid + " " + token; + } + } + return SdpItem::toString(); +} + +void SdpAttrRtcp::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() == 4); + port = atoi(vec[0].data()); + nettype = vec[1]; + addrtype = vec[2]; + address = vec[3]; +} + +string SdpAttrRtcp::toString() const { + if (value.empty()) { + value = to_string(port) + " " + nettype + " " + addrtype + " " + address; + } + return SdpItem::toString(); +} + +void SdpAttrIceOption::parse(const string &str) { + auto vec = split(str, " "); + for (auto &v : vec) { + if (!strcasecmp(v.data(), "trickle")) { + trickle = true; + continue; + } + if (!strcasecmp(v.data(), "renomination")) { + renomination = true; + continue; + } + } +} + +string SdpAttrIceOption::toString() const { + if (value.empty()) { + if (trickle && renomination) { + value = "trickle renomination"; + } else if (trickle) { + value = "trickle"; + } else if (renomination) { + value = "renomination"; + } + } + return value; +} + +void SdpAttrFingerprint::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() == 2); + algorithm = vec[0]; + hash = vec[1]; +} + +string SdpAttrFingerprint::toString() const { + if (value.empty()) { + value = algorithm + " " + hash; + } + return SdpItem::toString(); +} + +void SdpAttrSetup::parse(const string &str) { + role = getDtlsRole(str); + CHECK_SDP(role != DtlsRole::invalid); +} + +string SdpAttrSetup::toString() const { + if (value.empty()) { + value = getDtlsRoleString(role); + } + return SdpItem::toString(); +} + +void SdpAttrExtmap::parse(const string &str) { + char buf[128] = { 0 }; + char direction_buf[32] = { 0 }; + if (sscanf(str.data(), "%" SCNd8 "/%31[^ ] %127s", &id, direction_buf, buf) != 3) { + CHECK_SDP(sscanf(str.data(), "%" SCNd8 " %127s", &id, buf) == 2); + direction = RtpDirection::sendrecv; + } else { + direction = getRtpDirection(direction_buf); + } + ext = buf; +} + +string SdpAttrExtmap::toString() const { + if (value.empty()) { + if (direction == RtpDirection::invalid || direction == RtpDirection::sendrecv) { + value = to_string((int)id) + " " + ext; + } else { + value = to_string((int)id) + "/" + getRtpDirectionString(direction) + " " + ext; + } + } + return SdpItem::toString(); +} + +void SdpAttrRtpMap::parse(const string &str) { + char buf[32] = { 0 }; + if (sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32 "/%" SCNd32, &pt, buf, &sample_rate, &channel) != 4) { + CHECK_SDP(sscanf(str.data(), "%" SCNu8 " %31[^/]/%" SCNd32, &pt, buf, &sample_rate) == 3); + if (getTrackType(getCodecId(buf)) == TrackAudio) { + // 未指定通道数时,且为音频时,那么通道数默认为1 [AUTO-TRANSLATED:bd128fbd] + // If the number of channels is not specified and it is audio, the number of channels defaults to 1 + channel = 1; + } + } + codec = buf; +} + +string SdpAttrRtpMap::toString() const { + if (value.empty()) { + value = to_string((int)pt) + " " + codec + "/" + to_string(sample_rate); + if (channel) { + value += '/'; + value += to_string(channel); + } + } + return SdpItem::toString(); +} + +void SdpAttrRtcpFb::parse(const string &str_in) { + auto str = str_in + "\n"; + char rtcp_type_buf[32] = { 0 }; + CHECK_SDP(sscanf(str.data(), "%" SCNu8 " %31[^\n]", &pt, rtcp_type_buf) == 2); + rtcp_type = rtcp_type_buf; +} + +string SdpAttrRtcpFb::toString() const { + if (value.empty()) { + value = to_string((int)pt) + " " + rtcp_type; + } + return SdpItem::toString(); +} + +void SdpAttrFmtp::parse(const string &str) { + auto pos = str.find(' '); + CHECK_SDP(pos != string::npos); + pt = atoi(str.substr(0, pos).data()); + auto vec = split(str.substr(pos + 1), ";"); + for (auto &item : vec) { + trim(item); + auto pos = item.find('='); + if (pos == string::npos) { + fmtp.emplace(std::make_pair(item, "")); + } else { + fmtp.emplace(std::make_pair(item.substr(0, pos), item.substr(pos + 1))); + } + } + CHECK_SDP(!fmtp.empty()); +} + +string SdpAttrFmtp::toString() const { + if (value.empty()) { + value = to_string((int)pt); + int i = 0; + for (auto &pr : fmtp) { + value += (i++ ? ';' : ' '); + value += pr.first + "=" + pr.second; + } + } + return SdpItem::toString(); +} + +void SdpAttrSSRC::parse(const string &str_in) { + auto str = str_in + '\n'; + char attr_buf[32] = { 0 }; + char attr_val_buf[128] = { 0 }; + if (3 == sscanf(str.data(), "%" SCNu32 " %31[^:]:%127[^\n]", &ssrc, attr_buf, attr_val_buf)) { + attribute = attr_buf; + attribute_value = attr_val_buf; + } else if (2 == sscanf(str.data(), "%" SCNu32 " %31s[^\n]", &ssrc, attr_buf)) { + attribute = attr_buf; + } else { + CHECK_SDP(0); + } +} + +string SdpAttrSSRC::toString() const { + if (value.empty()) { + value = to_string(ssrc) + ' '; + value += attribute; + if (!attribute_value.empty()) { + value += ':'; + value += attribute_value; + } + } + return SdpItem::toString(); +} + +void SdpAttrSSRCGroup::parse(const string &str) { + auto vec = split(str, " "); + CHECK_SDP(vec.size() >= 3); + type = std::move(vec[0]); + CHECK(isFID() || isSIM()); + vec.erase(vec.begin()); + for (auto ssrc : vec) { + ssrcs.emplace_back((uint32_t)atoll(ssrc.data())); + } +} + +string SdpAttrSSRCGroup::toString() const { + if (value.empty()) { + value = type; + // 最少要求2个ssrc [AUTO-TRANSLATED:968acb83] + // At least 2 SSRCs are required + CHECK(ssrcs.size() >= 2); + for (auto &ssrc : ssrcs) { + value += ' '; + value += to_string(ssrc); + } + } + return SdpItem::toString(); +} + +void SdpAttrSctpMap::parse(const string &str) { + char subtypes_buf[64] = { 0 }; + CHECK_SDP(3 == sscanf(str.data(), "%" SCNu16 " %63[^ ] %" SCNd32, &port, subtypes_buf, &streams)); + subtypes = subtypes_buf; +} + +string SdpAttrSctpMap::toString() const { + if (value.empty()) { + value = to_string(port); + value += ' '; + value += subtypes; + value += ' '; + value += to_string(streams); + } + return SdpItem::toString(); +} + +void SdpAttrCandidate::parse(const string &str) { + char foundation_buf[40] = { 0 }; + char transport_buf[16] = { 0 }; + char address_buf[64] = { 0 }; + char type_buf[16] = { 0 }; + + // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 + CHECK_SDP(sscanf(str.data(), "%32[^ ] %" SCNu32 " %15[^ ] %" SCNu32 " %63[^ ] %" SCNu16 " typ %15[^ ]", + foundation_buf, &component, transport_buf, &priority, address_buf, &port, type_buf) == 7); + foundation = foundation_buf; + transport = transport_buf; + address = address_buf; + type = type_buf; + auto pos = str.find(type); + if (pos != string::npos) { + auto remain = str.substr(pos + type.size()); + trim(remain); + if (!remain.empty()) { + auto vec = split(remain, " "); + string key; + for (auto &item : vec) { + if (key.empty()) { + key = item; + } else { + arr.emplace_back(std::make_pair(std::move(key), std::move(item))); + } + } + } + } +} + +string SdpAttrCandidate::toString() const { + if (value.empty()) { + value = foundation + " " + to_string(component) + " " + transport + " " + to_string(priority) + " " + address + " " + to_string(port) + " typ " + type; + for (auto &pr : arr) { + value += ' '; + value += pr.first; + value += ' '; + value += pr.second; + } + } + return SdpItem::toString(); +} + +void SdpAttrSimulcast::parse(const string &str) { + // https://www.meetecho.com/blog/simulcast-janus-ssrc/ + // a=simulcast:send/recv q;h;f + // a=simulcast:send/recv [rid=]q;h;f + // a=simulcast: recv h;m;l + // + auto vec = split(str, " "); + CHECK_SDP(vec.size() == 2); + direction = vec[0]; + rids = split(vec[1], ";"); +} + +string SdpAttrSimulcast::toString() const { + if (value.empty()) { + value = direction + " "; + bool first = true; + for (auto &rid : rids) { + if (first) { + first = false; + } else { + value += ';'; + } + value += rid; + } + } + return SdpItem::toString(); +} + +void SdpAttrRid::parse(const string &str) { + auto vec = split(str, " "); + CHECK(vec.size() >= 2); + rid = vec[0]; + direction = vec[1]; +} + +string SdpAttrRid::toString() const { + if (value.empty()) { + value = rid + " " + direction; + } + return SdpItem::toString(); +} + +void RtcSession::loadFrom(const string &str) { + RtcSessionSdp sdp; + sdp.parse(str); + + version = sdp.getVersion(); + origin = sdp.getOrigin(); + session_name = sdp.getSessionName(); + session_info = sdp.getSessionInfo(); + connection = sdp.getConnection(); + time = sdp.getSessionTime(); + msid_semantic = sdp.getItemClass('a', "msid-semantic"); + for (auto &media : sdp.medias) { + auto mline = media.getItemClass('m'); + this->media.emplace_back(); + auto &rtc_media = this->media.back(); + rtc_media.mid = media.getStringItem('a', "mid"); + rtc_media.proto = mline.proto; + rtc_media.type = mline.type; + rtc_media.port = mline.port; + rtc_media.addr = media.getItemClass('c'); + rtc_media.bandwidth = media.getItemClass('b'); + rtc_media.ice_ufrag = media.getStringItem('a', "ice-ufrag"); + rtc_media.ice_pwd = media.getStringItem('a', "ice-pwd"); + rtc_media.role = media.getItemClass('a', "setup").role; + rtc_media.fingerprint = media.getItemClass('a', "fingerprint"); + if (rtc_media.fingerprint.empty()) { + rtc_media.fingerprint = sdp.getItemClass('a', "fingerprint"); + } + rtc_media.ice_lite = media.getItem('a', "ice-lite").operator bool(); + auto ice_options = media.getItemClass('a', "ice-options"); + rtc_media.ice_trickle = ice_options.trickle; + rtc_media.ice_renomination = ice_options.renomination; + rtc_media.candidate = media.getAllItem('a', "candidate"); + + if (mline.type == TrackType::TrackApplication) { + rtc_media.sctp_port = atoi(media.getStringItem('a', "sctp-port").data()); + rtc_media.sctpmap = media.getItemClass('a', "sctpmap"); + continue; + } + rtc_media.rtcp_addr = media.getItemClass('a', "rtcp"); + rtc_media.direction = media.getDirection(); + rtc_media.extmap = media.getAllItem('a', "extmap"); + rtc_media.rtcp_mux = media.getItem('a', "rtcp-mux").operator bool(); + rtc_media.rtcp_rsize = media.getItem('a', "rtcp-rsize").operator bool(); + + map rtc_ssrc_map; + auto ssrc_attr = media.getAllItem('a', "ssrc"); + for (auto &ssrc : ssrc_attr) { + auto &rtc_ssrc = rtc_ssrc_map[ssrc.ssrc]; + rtc_ssrc.ssrc = ssrc.ssrc; + if (!strcasecmp(ssrc.attribute.data(), "cname")) { + rtc_ssrc.cname = ssrc.attribute_value; + continue; + } + if (!strcasecmp(ssrc.attribute.data(), "msid")) { + rtc_ssrc.msid = ssrc.attribute_value; + continue; + } + if (!strcasecmp(ssrc.attribute.data(), "mslabel")) { + rtc_ssrc.mslabel = ssrc.attribute_value; + continue; + } + if (!strcasecmp(ssrc.attribute.data(), "label")) { + rtc_ssrc.label = ssrc.attribute_value; + continue; + } + } + + auto ssrc_groups = media.getAllItem('a', "ssrc-group"); + bool have_rtx_ssrc = false; + SdpAttrSSRCGroup *ssrc_group_sim = nullptr; + for (auto &group : ssrc_groups) { + if (group.isFID()) { + have_rtx_ssrc = true; + // ssrc-group:FID字段必须包含rtp/rtx的ssrc [AUTO-TRANSLATED:3da97d7d] + // The ssrc-group:FID field must contain the SSRCs of rtp/rtx + CHECK(group.ssrcs.size() == 2); + // 根据rtp ssrc找到对象 [AUTO-TRANSLATED:c0a56b42] + // Find the object based on the RTP SSRC + auto it = rtc_ssrc_map.find(group.ssrcs[0]); + CHECK(it != rtc_ssrc_map.end()); + // 设置rtx ssrc [AUTO-TRANSLATED:422e2a55] + // Set the RTX SSRC + it->second.rtx_ssrc = group.ssrcs[1]; + rtc_media.rtp_rtx_ssrc.emplace_back(it->second); + } else if (group.isSIM()) { + CHECK(!ssrc_group_sim); + ssrc_group_sim = &group; + } + } + + if (!have_rtx_ssrc) { + // 按照sdp顺序依次添加ssrc [AUTO-TRANSLATED:0996ba7e] + // Add SSRCs in the order of SDP + for (auto &attr : ssrc_attr) { + if (attr.attribute == "cname") { + rtc_media.rtp_rtx_ssrc.emplace_back(rtc_ssrc_map[attr.ssrc]); + } + } + } + + auto simulcast = media.getItemClass('a', "simulcast"); + if (!simulcast.empty()) { + // a=rid:h send + // a=rid:m send + // a=rid:l send + // a=simulcast:send h;m;l + // 风格的simulcast [AUTO-TRANSLATED:94ac2d55] + // Style of simulcast + unordered_set rid_map; + for (auto &rid : simulcast.rids) { + rid_map.emplace(rid); + } + for (auto &rid : media.getAllItem('a', "rid")) { + CHECK(rid.direction == simulcast.direction); + CHECK(rid_map.find(rid.rid) != rid_map.end()); + } + // simulcast最少要求2种方案 [AUTO-TRANSLATED:31732a7a] + // Simulcast requires at least 2 schemes + CHECK(simulcast.rids.size() >= 2); + rtc_media.rtp_rids = simulcast.rids; + } + + if (ssrc_group_sim) { + // 指定了a=ssrc-group:SIM [AUTO-TRANSLATED:5732661e] + // a=ssrc-group:SIM is specified + for (auto ssrc : ssrc_group_sim->ssrcs) { + auto it = rtc_ssrc_map.find(ssrc); + CHECK(it != rtc_ssrc_map.end()); + rtc_media.rtp_ssrc_sim.emplace_back(it->second); + } + } else if (!rtc_media.rtp_rids.empty()) { + // 未指定a=ssrc-group:SIM, 但是指定了a=simulcast, 那么只能根据ssrc顺序来对应rid顺序 [AUTO-TRANSLATED:b198a817] + // a=ssrc-group:SIM is not specified, but a=simulcast is specified, so the RID order can only be matched according to the SSRC order + rtc_media.rtp_ssrc_sim = rtc_media.rtp_rtx_ssrc; + } + + if (!rtc_media.supportSimulcast()) { + // 不支持simulcast的情况下,最多一组ssrc [AUTO-TRANSLATED:3ea8ed65] + // In the case of not supporting simulcast, there is at most one group of SSRCs + CHECK(rtc_media.rtp_rtx_ssrc.size() <= 1); + } else { + // simulcast的情况下,要么没有指定ssrc,要么指定的ssrc个数与rid个数一致 [AUTO-TRANSLATED:1d45ce03] + // In the case of simulcast, either no SSRC is specified or the number of specified SSRCs is consistent with the number of RIDs + // CHECK(rtc_media.rtp_ssrc_sim.empty() || rtc_media.rtp_ssrc_sim.size() == rtc_media.rtp_rids.size()); + } + + auto rtpmap_arr = media.getAllItem('a', "rtpmap"); + auto rtcpfb_arr = media.getAllItem('a', "rtcp-fb"); + auto fmtp_aar = media.getAllItem('a', "fmtp"); + // 方便根据pt查找rtpmap,一个pt必有一条 [AUTO-TRANSLATED:c3673faa] + // Convenient to find rtpmap based on pt, one pt must have one + map rtpmap_map; + // 方便根据pt查找rtcp-fb,一个pt可能有多条或0条 [AUTO-TRANSLATED:38361f68] + // Convenient to find rtcp-fb based on pt, one pt may have multiple or 0 + multimap rtcpfb_map; + // 方便根据pt查找fmtp,一个pt最多一条 [AUTO-TRANSLATED:be5d652d] + // Convenient to find fmtp based on pt, one pt has at most one + map fmtp_map; + + for (auto &rtpmap : rtpmap_arr) { + // 添加失败,有多条 [AUTO-TRANSLATED:717782c0] + // Add failed, there are multiple + CHECK(rtpmap_map.emplace(rtpmap.pt, rtpmap).second, "该pt存在多条a=rtpmap:", (int)rtpmap.pt); + } + for (auto &rtpfb : rtcpfb_arr) { + rtcpfb_map.emplace(rtpfb.pt, rtpfb); + } + for (auto &fmtp : fmtp_aar) { + // 添加失败,有多条 [AUTO-TRANSLATED:717782c0] + // Add failed, there are multiple + CHECK(fmtp_map.emplace(fmtp.pt, fmtp).second, "该pt存在多条a=fmtp:", (int)fmtp.pt); + } + for (auto &item : mline.fmts) { + auto pt = atoi(item.c_str()); + CHECK(pt < 0xFF, "invalid payload type: ", item); + // 遍历所有编码方案的pt [AUTO-TRANSLATED:40f2db36] + // Traverse the pt of all encoding schemes + rtc_media.plan.emplace_back(); + auto &plan = rtc_media.plan.back(); + auto rtpmap_it = rtpmap_map.find(pt); + if (rtpmap_it == rtpmap_map.end()) { + plan.pt = pt; + plan.codec = RtpPayload::getName(pt); + plan.sample_rate = RtpPayload::getClockRate(pt); + plan.channel = RtpPayload::getAudioChannel(pt); + } else { + plan.pt = rtpmap_it->second.pt; + plan.codec = rtpmap_it->second.codec; + plan.sample_rate = rtpmap_it->second.sample_rate; + plan.channel = rtpmap_it->second.channel; + } + + auto fmtp_it = fmtp_map.find(pt); + if (fmtp_it != fmtp_map.end()) { + plan.fmtp = fmtp_it->second.fmtp; + } + for (auto rtpfb_it = rtcpfb_map.find(pt); rtpfb_it != rtcpfb_map.end() && rtpfb_it->second.pt == pt; ++rtpfb_it) { + plan.rtcp_fb.emplace(rtpfb_it->second.rtcp_type); + } + } + } + + group = sdp.getItemClass('a', "group"); +} + +void RtcSdpBase::toRtsp() { + for (auto it = items.begin(); it != items.end();) { + switch ((*it)->getKey()[0]) { + case 'v': + case 'o': + case 's': + case 'i': + case 't': + case 'c': + case 'b': { + ++it; + break; + } + + case 'm': { + auto m = dynamic_pointer_cast(*it); + CHECK(m); + m->proto = "RTP/AVP"; + ++it; + break; + } + case 'a': { + auto attr = dynamic_pointer_cast(*it); + CHECK(attr); + if (!strcasecmp(attr->detail->getKey(), "rtpmap") || !strcasecmp(attr->detail->getKey(), "fmtp")) { + ++it; + break; + } + } + default: { + it = items.erase(it); + break; + } + } + } +} + +string RtcSession::toRtspSdp() const { + RtcSession copy = *this; + copy.media.clear(); + for (auto &m : media) { + switch (m.type) { + case TrackAudio: + case TrackVideo: { + if (m.direction != RtpDirection::inactive) { + copy.media.emplace_back(m); + copy.media.back().plan.resize(1); + } + break; + } + default: continue; + } + } + + CHECK(!copy.media.empty()); + auto sdp = copy.toRtcSessionSdp(); + sdp->toRtsp(); + int i = 0; + for (auto &m : sdp->medias) { + m.toRtsp(); + m.addAttr(std::make_shared("control", string("trackID=") + to_string(i++))); + } + return sdp->toString(); +} + +void addSdpAttrSSRC(const RtcSSRC &rtp_ssrc, RtcSdpBase &media, uint32_t ssrc_num) { + assert(ssrc_num); + SdpAttrSSRC ssrc; + ssrc.ssrc = ssrc_num; + + ssrc.attribute = "cname"; + ssrc.attribute_value = rtp_ssrc.cname; + media.addAttr(std::make_shared(ssrc)); + + if (!rtp_ssrc.msid.empty()) { + ssrc.attribute = "msid"; + ssrc.attribute_value = rtp_ssrc.msid; + media.addAttr(std::make_shared(ssrc)); + } + + if (!rtp_ssrc.mslabel.empty()) { + ssrc.attribute = "mslabel"; + ssrc.attribute_value = rtp_ssrc.mslabel; + media.addAttr(std::make_shared(ssrc)); + } + + if (!rtp_ssrc.label.empty()) { + ssrc.attribute = "label"; + ssrc.attribute_value = rtp_ssrc.label; + media.addAttr(std::make_shared(ssrc)); + } +} + +RtcSessionSdp::Ptr RtcSession::toRtcSessionSdp() const { + RtcSessionSdp::Ptr ret = std::make_shared(); + auto &sdp = *ret; + sdp.addItem(std::make_shared>(to_string(version))); + sdp.addItem(std::make_shared(origin)); + sdp.addItem(std::make_shared>(session_name)); + if (!session_info.empty()) { + sdp.addItem(std::make_shared>(session_info)); + } + sdp.addItem(std::make_shared(time)); + if (connection.empty()) { + sdp.addItem(std::make_shared(connection)); + } + sdp.addAttr(std::make_shared(group)); + sdp.addAttr(std::make_shared()); + sdp.addAttr(std::make_shared(msid_semantic)); + + bool ice_lite = false; + + for (auto &m : media) { + sdp.medias.emplace_back(); + auto &sdp_media = sdp.medias.back(); + auto mline = std::make_shared(); + mline->type = m.type; + mline->port = m.port; + mline->proto = m.proto; + for (auto &p : m.plan) { + mline->fmts.emplace_back(to_string((int)p.pt)); + } + if (m.type == TrackApplication) { + mline->fmts.emplace_back("webrtc-datachannel"); + } + sdp_media.addItem(std::move(mline)); + sdp_media.addItem(std::make_shared(m.addr)); + if (!m.bandwidth.empty() && m.type != TrackAudio) { + sdp_media.addItem(std::make_shared(m.bandwidth)); + } + if (!m.rtcp_addr.empty()) { + sdp_media.addAttr(std::make_shared(m.rtcp_addr)); + } + + sdp_media.addAttr(std::make_shared(m.ice_ufrag)); + sdp_media.addAttr(std::make_shared(m.ice_pwd)); + if (m.ice_trickle || m.ice_renomination) { + auto attr = std::make_shared(); + attr->trickle = m.ice_trickle; + attr->renomination = m.ice_renomination; + sdp_media.addAttr(attr); + } + sdp_media.addAttr(std::make_shared(m.fingerprint)); + sdp_media.addAttr(std::make_shared(m.role)); + sdp_media.addAttr(std::make_shared(m.mid)); + if (m.ice_lite) { + sdp_media.addAttr(std::make_shared("ice-lite")); + ice_lite = true; + } + for (auto &ext : m.extmap) { + sdp_media.addAttr(std::make_shared(ext)); + } + if (m.direction != RtpDirection::invalid) { + sdp_media.addAttr(std::make_shared(m.direction)); + } + if (m.rtcp_mux) { + sdp_media.addAttr(std::make_shared("rtcp-mux")); + } + if (m.rtcp_rsize) { + sdp_media.addAttr(std::make_shared("rtcp-rsize")); + } + + if (m.type != TrackApplication) { + for (auto &p : m.plan) { + auto rtp_map = std::make_shared(); + rtp_map->pt = p.pt; + rtp_map->codec = p.codec; + rtp_map->sample_rate = p.sample_rate; + rtp_map->channel = p.channel; + // 添加a=rtpmap [AUTO-TRANSLATED:8bef5d64] + // Add a=rtpmap + sdp_media.addAttr(std::move(rtp_map)); + + for (auto &fb : p.rtcp_fb) { + auto rtcp_fb = std::make_shared(); + rtcp_fb->pt = p.pt; + rtcp_fb->rtcp_type = fb; + // 添加a=rtcp-fb [AUTO-TRANSLATED:11754b43] + // Add a=rtcp-fb + sdp_media.addAttr(std::move(rtcp_fb)); + } + + if (!p.fmtp.empty()) { + auto fmtp = std::make_shared(); + fmtp->pt = p.pt; + fmtp->fmtp = p.fmtp; + // 添加a=fmtp [AUTO-TRANSLATED:594a4425] + // Add a=fmtp + sdp_media.addAttr(std::move(fmtp)); + } + } + + { + // 添加a=msid字段 [AUTO-TRANSLATED:cf2c1471] + // Add a=msid field + if (!m.rtp_rtx_ssrc.empty()) { + if (!m.rtp_rtx_ssrc[0].msid.empty()) { + auto msid = std::make_shared(); + msid->parse(m.rtp_rtx_ssrc[0].msid); + sdp_media.addAttr(std::move(msid)); + } + } + } + + { + for (auto &ssrc : m.rtp_rtx_ssrc) { + // 添加a=ssrc字段 [AUTO-TRANSLATED:75ca5225] + // Add a=ssrc field + CHECK(!ssrc.empty()); + addSdpAttrSSRC(ssrc, sdp_media, ssrc.ssrc); + if (ssrc.rtx_ssrc) { + addSdpAttrSSRC(ssrc, sdp_media, ssrc.rtx_ssrc); + + // 生成a=ssrc-group:FID字段 [AUTO-TRANSLATED:22b1f966] + // Generate a=ssrc-group:FID field + // 有rtx ssrc [AUTO-TRANSLATED:fece8076] + // There is rtx ssrc + auto group = std::make_shared(); + group->type = "FID"; + group->ssrcs.emplace_back(ssrc.ssrc); + group->ssrcs.emplace_back(ssrc.rtx_ssrc); + sdp_media.addAttr(std::move(group)); + } + } + } + + { + if (m.rtp_ssrc_sim.size() >= 2) { + // simulcast 要求 2~3路 [AUTO-TRANSLATED:3237ffca] + // Simulcast requires 2~3 channels + auto group = std::make_shared(); + for (auto &ssrc : m.rtp_ssrc_sim) { + group->ssrcs.emplace_back(ssrc.ssrc); + } + // 添加a=ssrc-group:SIM字段 [AUTO-TRANSLATED:46b04aae] + // Add a=ssrc-group:SIM field + group->type = "SIM"; + sdp_media.addAttr(std::move(group)); + } + + if (m.rtp_rids.size() >= 2) { + auto simulcast = std::make_shared(); + simulcast->direction = "recv"; + simulcast->rids = m.rtp_rids; + sdp_media.addAttr(std::move(simulcast)); + + for (auto &rid : m.rtp_rids) { + auto attr_rid = std::make_shared(); + attr_rid->rid = rid; + attr_rid->direction = "recv"; + sdp_media.addAttr(std::move(attr_rid)); + } + } + } + + } else { + if (!m.sctpmap.empty()) { + sdp_media.addAttr(std::make_shared(m.sctpmap)); + } + sdp_media.addAttr(std::make_shared("sctp-port", to_string(m.sctp_port))); + } + + for (auto &cand : m.candidate) { + if (cand.port) { + sdp_media.addAttr(std::make_shared(cand)); + } + } + } + if (ice_lite) { + sdp.addAttr(std::make_shared("ice-lite")); + } + return ret; +} + +string RtcSession::toString() const { + return toRtcSessionSdp()->toString(); +} + +string RtcCodecPlan::getFmtp(const char *key) const { + for (auto &item : fmtp) { + if (strcasecmp(item.first.data(), key) == 0) { + return item.second; + } + } + return ""; +} + +const RtcCodecPlan *RtcMedia::getPlan(uint8_t pt) const { + for (auto &item : plan) { + if (item.pt == pt) { + return &item; + } + } + return nullptr; +} + +const RtcCodecPlan *RtcMedia::getPlan(const char *codec) const { + for (auto &item : plan) { + if (strcasecmp(item.codec.data(), codec) == 0) { + return &item; + } + } + return nullptr; +} + +const RtcCodecPlan *RtcMedia::getRelatedRtxPlan(uint8_t pt) const { + for (auto &item : plan) { + if (strcasecmp(item.codec.data(), "rtx") == 0) { + auto apt = atoi(item.getFmtp("apt").data()); + if (pt == apt) { + return &item; + } + } + } + return nullptr; +} + +uint32_t RtcMedia::getRtpSSRC() const { + if (rtp_rtx_ssrc.size()) { + return rtp_rtx_ssrc[0].ssrc; + } + return 0; +} + +uint32_t RtcMedia::getRtxSSRC() const { + if (rtp_rtx_ssrc.size()) { + return rtp_rtx_ssrc[0].rtx_ssrc; + } + return 0; +} + +bool RtcMedia::supportSimulcast() const { + if (!rtp_rids.empty()) { + return true; + } + if (!rtp_ssrc_sim.empty()) { + return true; + } + return false; +} + +void RtcMedia::checkValid() const { + CHECK(type != TrackInvalid); + CHECK(!mid.empty()); + CHECK(!proto.empty()); + CHECK(direction != RtpDirection::invalid || type == TrackApplication); + CHECK(!plan.empty() || type == TrackApplication); + CHECK(type == TrackApplication || rtcp_mux, "只支持rtcp-mux模式"); + + bool send_rtp = (direction == RtpDirection::sendonly || direction == RtpDirection::sendrecv); + if (!supportSimulcast()) { + // 非simulcast时,检查有没有指定rtp ssrc [AUTO-TRANSLATED:e2d53f8a] + // When not simulcast, check if the RTP SSRC is specified + CHECK(!rtp_rtx_ssrc.empty() || !send_rtp); + + for (auto ssrc : rtp_rtx_ssrc) { + InfoL << "ssrc:" << ssrc.cname << "," << ssrc.msid; + } + } + +#if 0 + // todo 发现Firefox(88.0)在mac平台下,开启rtx后没有指定ssrc [AUTO-TRANSLATED:9112d91a] + // todo Found that Firefox (88.0) on the mac platform does not specify ssrc when rtx is enabled + auto rtx_plan = getPlan("rtx"); + if (rtx_plan) { + // 开启rtx后必须指定rtx_ssrc [AUTO-TRANSLATED:c527f68d] + // RTX must be specified after rtx_ssrc is enabled + CHECK(rtp_rtx_ssrc.size() >= 2 || !send_rtp); + } +#endif +} + +void RtcSession::checkValid() const { + CHECK(version == 0); + CHECK(!origin.empty()); + CHECK(!session_name.empty()); + CHECK(!msid_semantic.empty()); + CHECK(!media.empty()); + CHECK(!group.mids.empty() && group.mids.size() <= media.size(), "只支持group BUNDLE模式"); + + bool have_active_media = false; + for (auto &item : media) { + item.checkValid(); + + if (TrackApplication == item.type) { + have_active_media = true; + } + switch (item.direction) { + case RtpDirection::sendrecv: + case RtpDirection::sendonly: + case RtpDirection::recvonly: have_active_media = true; break; + default: break; + } + } + CHECK(have_active_media, "必须确保最少有一个活跃的track"); +} + +const RtcMedia *RtcSession::getMedia(TrackType type) const { + for (auto &m : media) { + if (m.type == type) { + return &m; + } + } + return nullptr; +} + +bool RtcSession::supportRtcpFb(const string &name, TrackType type) const { + auto media = getMedia(type); + if (!media) { + return false; + } + auto &ref = media->plan[0].rtcp_fb; + return ref.find(name) != ref.end(); +} + +bool RtcSession::supportSimulcast() const { + for (auto &m : media) { + if (m.supportSimulcast()) { + return true; + } + } + return false; +} + +bool RtcSession::isOnlyDatachannel() const { + return 1 == media.size() && TrackApplication == media[0].type; +} + +string const SdpConst::kTWCCRtcpFb = "transport-cc"; +string const SdpConst::kRembRtcpFb = "goog-remb"; + +void RtcConfigure::RtcTrackConfigure::enableTWCC(bool enable) { + if (!enable) { + rtcp_fb.erase(SdpConst::kTWCCRtcpFb); + extmap.erase(RtpExtType::transport_cc); + } else { + rtcp_fb.emplace(SdpConst::kTWCCRtcpFb); + extmap.emplace(RtpExtType::transport_cc, RtpDirection::sendrecv); + } +} + +void RtcConfigure::RtcTrackConfigure::enableREMB(bool enable) { + if (!enable) { + rtcp_fb.erase(SdpConst::kRembRtcpFb); + extmap.erase(RtpExtType::abs_send_time); + } else { + rtcp_fb.emplace(SdpConst::kRembRtcpFb); + extmap.emplace(RtpExtType::abs_send_time, RtpDirection::sendrecv); + } +} + +static vector toCodecArray(const string &str) { + vector ret; + auto vec = split(str, ","); + for (auto &s : vec) { + auto codec = getCodecId(trim(s)); + if (codec != CodecInvalid) { + ret.emplace_back(codec); + } + } + return ret; +} + +void RtcConfigure::RtcTrackConfigure::setDefaultSetting(TrackType type) { + rtcp_mux = true; + rtcp_rsize = false; + group_bundle = true; + support_rtx = true; + support_red = false; + support_ulpfec = false; + ice_lite = true; + ice_trickle = true; + ice_renomination = false; + switch (type) { + case TrackAudio: { + // 此处调整偏好的编码格式优先级 [AUTO-TRANSLATED:b8719e66] + // Adjust the priority of preferred encoding formats here + GET_CONFIG_FUNC(vector, s_preferred_codec, Rtc::kPreferredCodecA, toCodecArray); + CHECK(!s_preferred_codec.empty(), "rtc音频偏好codec不能为空"); + preferred_codec = s_preferred_codec; + + rtcp_fb = { SdpConst::kTWCCRtcpFb, SdpConst::kRembRtcpFb }; + extmap = { { RtpExtType::ssrc_audio_level, RtpDirection::sendrecv }, + { RtpExtType::csrc_audio_level, RtpDirection::sendrecv }, + { RtpExtType::abs_send_time, RtpDirection::sendrecv }, + { RtpExtType::transport_cc, RtpDirection::sendrecv }, + // rtx重传rtp时,忽略sdes_mid类型的rtp ext,实测发现Firefox在接收rtx时,如果存在sdes_mid的ext,将导致无法播放 [AUTO-TRANSLATED:221df025] + // When rtx retransmits rtp, ignore the rtp ext of sdes_mid type. It is found that Firefox cannot play when receiving rtx if there is an ext of sdes_mid + //{RtpExtType::sdes_mid,RtpDirection::sendrecv}, + { RtpExtType::sdes_rtp_stream_id, RtpDirection::sendrecv }, + { RtpExtType::sdes_repaired_rtp_stream_id, RtpDirection::sendrecv } }; + break; + } + case TrackVideo: { + // 此处调整偏好的编码格式优先级 [AUTO-TRANSLATED:b8719e66] + // Adjust the priority of preferred encoding formats here + GET_CONFIG_FUNC(vector, s_preferred_codec, Rtc::kPreferredCodecV, toCodecArray); + CHECK(!s_preferred_codec.empty(), "rtc视频偏好codec不能为空"); + preferred_codec = s_preferred_codec; + + rtcp_fb = { SdpConst::kTWCCRtcpFb, SdpConst::kRembRtcpFb, "nack", "ccm fir", "nack pli" }; + extmap = { { RtpExtType::abs_send_time, RtpDirection::sendrecv }, + { RtpExtType::transport_cc, RtpDirection::sendrecv }, + // rtx重传rtp时,忽略sdes_mid类型的rtp ext,实测发现Firefox在接收rtx时,如果存在sdes_mid的ext,将导致无法播放 [AUTO-TRANSLATED:221df025] + // When rtx retransmits rtp, ignore the rtp ext of sdes_mid type. It is found that Firefox cannot play when receiving rtx if there is an ext of sdes_mid + //{RtpExtType::sdes_mid,RtpDirection::sendrecv}, + { RtpExtType::sdes_rtp_stream_id, RtpDirection::sendrecv }, + { RtpExtType::sdes_repaired_rtp_stream_id, RtpDirection::sendrecv }, + { RtpExtType::video_timing, RtpDirection::sendrecv }, + { RtpExtType::color_space, RtpDirection::sendrecv }, + { RtpExtType::video_content_type, RtpDirection::sendrecv }, + { RtpExtType::playout_delay, RtpDirection::sendrecv }, + // 手机端推webrtc 会带有旋转角度,rtc协议能正常播放 其他协议拉流画面旋转 [AUTO-TRANSLATED:3f2f9e0e] + // Mobile push webrtc will have a rotation angle, rtc protocol can play normally, other protocols pull stream picture rotation + //{RtpExtType::video_orientation, RtpDirection::sendrecv}, + { RtpExtType::toffset, RtpDirection::sendrecv }, + { RtpExtType::framemarking, RtpDirection::sendrecv } }; + break; + } + case TrackApplication: { + break; + } + default: break; + } +} + +void RtcConfigure::setDefaultSetting(string ice_ufrag, string ice_pwd, RtpDirection direction, const SdpAttrFingerprint &fingerprint) { + video.setDefaultSetting(TrackVideo); + audio.setDefaultSetting(TrackAudio); + application.setDefaultSetting(TrackApplication); + + video.ice_ufrag = audio.ice_ufrag = application.ice_ufrag = std::move(ice_ufrag); + video.ice_pwd = audio.ice_pwd = application.ice_pwd = std::move(ice_pwd); + video.direction = audio.direction = application.direction = direction; + video.fingerprint = audio.fingerprint = application.fingerprint = fingerprint; +} + +void RtcConfigure::addCandidate(const SdpAttrCandidate &candidate, TrackType type) { + switch (type) { + case TrackAudio: { + audio.candidate.emplace_back(candidate); + break; + } + case TrackVideo: { + video.candidate.emplace_back(candidate); + break; + } + case TrackApplication: { + application.candidate.emplace_back(candidate); + break; + } + default: { + if (audio.group_bundle) { + audio.candidate.emplace_back(candidate); + } + if (video.group_bundle) { + video.candidate.emplace_back(candidate); + } + if (application.group_bundle) { + application.candidate.emplace_back(candidate); + } + break; + } + } +} + +void RtcConfigure::enableTWCC(bool enable, TrackType type) { + switch (type) { + case TrackAudio: { + audio.enableTWCC(enable); + break; + } + case TrackVideo: { + video.enableTWCC(enable); + break; + } + default: { + audio.enableTWCC(enable); + video.enableTWCC(enable); + break; + } + } +} + +void RtcConfigure::enableREMB(bool enable, TrackType type) { + switch (type) { + case TrackAudio: { + audio.enableREMB(enable); + break; + } + case TrackVideo: { + video.enableREMB(enable); + break; + } + default: { + audio.enableREMB(enable); + video.enableREMB(enable); + break; + } + } +} + +shared_ptr RtcConfigure::createOffer() const { + shared_ptr ret = std::make_shared(); + ret->version = 0; + ret->origin.session_id = std::to_string(makeRandNum()); + ret->origin.session_version = std::to_string(1); + ret->session_name = "-"; + + createMediaOffer(ret); + // 设置音视频端口复用 [AUTO-TRANSLATED:ffe27d17] + // Set audio and video port multiplexing + for (auto &m : ret->media) { + // The remote end has rejected (port 0) the m-section, so it should not be putting its mid in the group attribute. + if (m.port) { + ret->group.mids.emplace_back(m.mid); + } + } + + return ret; +} + +shared_ptr RtcConfigure::createAnswer(const RtcSession &offer) const { + shared_ptr ret = std::make_shared(); + ret->version = offer.version; + ret->origin = offer.origin; + ret->session_name = offer.session_name; + ret->msid_semantic = offer.msid_semantic; + + for (auto &m : offer.media) { + matchMedia(ret, m); + } + + // 设置音视频端口复用 [AUTO-TRANSLATED:ffe27d17] + // Set audio and video port multiplexing + if (!offer.group.mids.empty()) { + for (auto &m : ret->media) { + // The remote end has rejected (port 0) the m-section, so it should not be putting its mid in the group attribute. + if (m.port) { + ret->group.mids.emplace_back(m.mid); + } + } + } + return ret; +} + +static RtpDirection matchDirection(RtpDirection offer_direction, RtpDirection supported) { + switch (offer_direction) { + case RtpDirection::sendonly: { + if (supported != RtpDirection::recvonly && supported != RtpDirection::sendrecv) { + // 我们不支持接收 [AUTO-TRANSLATED:e4ef4034] + // We do not support receiving + return RtpDirection::inactive; + } + return RtpDirection::recvonly; + } + + case RtpDirection::recvonly: { + if (supported != RtpDirection::sendonly && supported != RtpDirection::sendrecv) { + // 我们不支持发送 [AUTO-TRANSLATED:6505a226] + // We do not support sending + return RtpDirection::inactive; + } + return RtpDirection::sendonly; + } + + // 对方支持发送接收,那么最终能力根据配置来决定 [AUTO-TRANSLATED:d234d603] + // The other party supports sending and receiving, so the final capability is determined by the configuration + case RtpDirection::sendrecv: return (supported == RtpDirection::invalid ? RtpDirection::inactive : supported); + case RtpDirection::inactive: return RtpDirection::inactive; + default: return RtpDirection::invalid; + } +} + +static DtlsRole mathDtlsRole(DtlsRole role) { + switch (role) { + case DtlsRole::actpass: + case DtlsRole::active: return DtlsRole::passive; + case DtlsRole::passive: return DtlsRole::active; + default: CHECK(0, "invalid role:", getDtlsRoleString(role)); return DtlsRole::passive; + } +} + +void RtcConfigure::createMediaOffer(const std::shared_ptr &ret) const { + int index = 0; + if (video.direction != RtpDirection::sendonly || _rtsp_video_plan) { + createMediaOfferEach(ret, TrackVideo, index++); + } + + if (audio.direction != RtpDirection::sendonly || _rtsp_audio_plan) { + createMediaOfferEach(ret, TrackAudio, index++); + } + // createMediaOfferEach(ret, TrackApplication, index++); +} + +void RtcConfigure::createMediaOfferEach(const std::shared_ptr &ret, TrackType type, int index) const { + // rtpmap + static std::multimap audio_list_ref, video_list_ref; + static toolkit::onceToken token([]() { + audio_list_ref.emplace(CodecG711U, make_shared("PCMU", 0, 8000)); + audio_list_ref.emplace(CodecG711A, make_shared("PCMA", 8, 8000)); + audio_list_ref.emplace(CodecOpus, make_shared("opus", 111, 48000)); + audio_list_ref.emplace(CodecAAC, make_shared("mpeg4-generic", 96, 48000)); + + video_list_ref.emplace(CodecH264, make_shared(102, 90000, PROFILE_H264_BASELINE)); + video_list_ref.emplace(CodecH264, make_shared(104, 90000, PROFILE_H264_MAIN)); + video_list_ref.emplace(CodecH264, make_shared(106, 90000, PROFILE_H264_HIGH)); + video_list_ref.emplace(CodecH265, make_shared(120, 90000, PROFILE_H265_MAIN)); + video_list_ref.emplace(CodecH265, make_shared(124, 90000, PROFILE_H265_MAIN10)); + video_list_ref.emplace(CodecH265, make_shared(126, 90000, PROFILE_H265_SCREEN)); + video_list_ref.emplace(CodecAV1, make_shared(35, 90000, 0)); + video_list_ref.emplace(CodecVP8, make_shared("VP8", 96, 90000)); + video_list_ref.emplace(CodecVP9, make_shared(98, 90000, 0)); + video_list_ref.emplace(CodecVP9, make_shared(100, 90000, 2)); + }); + + bool check_profile = true; + bool check_codec = true; + const RtcTrackConfigure *cfg_ptr = nullptr; + std::multimap* rtpMap = nullptr; + switch (type) { + case TrackAudio: cfg_ptr = &audio; rtpMap = &audio_list_ref; break; + case TrackVideo: cfg_ptr = &video; rtpMap = &video_list_ref; break; + case TrackApplication: cfg_ptr = &application; break; + default: return; + } + auto &configure = *cfg_ptr; + + if (type == TrackApplication) { + RtcMedia media; + media.role = DtlsRole::active; + media.ice_ufrag = configure.ice_ufrag; + media.ice_pwd = configure.ice_pwd; + media.fingerprint = configure.fingerprint; + // media.ice_lite = configure.ice_lite; + media.ice_lite = false; +#ifdef ENABLE_SCTP + media.candidate = configure.candidate; +#else + media.port = 9; //占位符,表示后续协商分配 + WarnL << "answer sdp忽略application mline, 请安装usrsctp后再测试datachannel功能"; +#endif + ret->media.emplace_back(media); + return; + } + + RtcMedia media; + media.type = type; + media.mid = to_string(index); + media.proto = "UDP/TLS/RTP/SAVPF"; + media.port = 9;//占位符,表示后续协商分配 + // media.addr = ; + // media.bandwidth = ; + // media.rtcp_addr = ; + media.rtcp_mux = true; + media.rtcp_rsize = true; + media.ice_trickle = true; + media.ice_renomination = configure.ice_renomination; + media.ice_ufrag = configure.ice_ufrag; + media.ice_pwd = configure.ice_pwd; + media.fingerprint = configure.fingerprint; + // media.ice_lite = configure.ice_lite; + media.ice_lite = false; + // candidate offer不生成candidate,反正也是错的 + // media.candidate = configure.candidate; + // copy simulicast setting + // media.rtp_rids =; + // media.rtp_ssrc_sim = ; + + media.role = DtlsRole::active; + + // 如果codec匹配失败,那么禁用该track [AUTO-TRANSLATED:037de9a8] + // If the codec matching fails, then disable the track + media.direction = configure.direction; + + //extmap + for (auto extmap : cfg_ptr->extmap) { +#if 0 + if (extmap.second != media.direction) { + continue; + } +#endif + SdpAttrExtmap attrExtmap; + attrExtmap.direction = extmap.second; + attrExtmap.id = (uint8_t)extmap.first; + attrExtmap.ext = RtpExt::getExtUrl(extmap.first); + + media.extmap.push_back(attrExtmap); + } + + //rtpmap + for (auto codec : cfg_ptr->preferred_codec) { + if (!rtpMap) continue; + auto range = rtpMap->equal_range(codec); + for (auto it = range.first; it != range.second; ++it) { + auto rtpmap = it->second; + RtcCodecPlan plan; + plan.codec = rtpmap->getCodeName(); + plan.pt = rtpmap->getPayload(); + plan.sample_rate = rtpmap->getClockRate(); + plan.rtcp_fb = cfg_ptr->rtcp_fb; + auto fmtp = rtpmap->getFmtp(); + for (const auto& pair : fmtp) { + plan.fmtp.emplace(pair); + } + media.plan.push_back(plan); + // add video rtx plan + if (rtpmap->getType() == TrackVideo) { + // a=rtpmap:108 rtx/90000 + // a=fmtp:108 apt=107 + RtcCodecPlan rtx; + rtx.codec = "rtx"; + rtx.pt = rtpmap->getPayload() + 1; + rtx.sample_rate = rtpmap->getClockRate(); + rtx.fmtp["apt"] = std::to_string(rtpmap->getPayload()); + media.plan.push_back(rtx); + } + } + } + + //msid + if (media.direction != RtpDirection::recvonly) { + RtcSSRC ssrc; + ssrc.ssrc = (uint32_t)makeRandNum(); + ssrc.rtx_ssrc = (uint32_t)makeRandNum(); + ssrc.cname = makeRandStr(16); + ssrc.msid = makeRandStr(36) + " " + makeUuidStr(); + media.rtp_rtx_ssrc.push_back(ssrc); + } + + ret->media.emplace_back(media); +} + +void RtcConfigure::matchMedia(const std::shared_ptr &ret, const RtcMedia &offer_media) const { + bool check_profile = true; + bool check_codec = true; + const RtcTrackConfigure *cfg_ptr = nullptr; + + switch (offer_media.type) { + case TrackAudio: cfg_ptr = &audio; break; + case TrackVideo: cfg_ptr = &video; break; + case TrackApplication: cfg_ptr = &application; break; + default: return; + } + auto &configure = *cfg_ptr; + +RETRY: + + if (offer_media.type == TrackApplication) { + RtcMedia answer_media = offer_media; + answer_media.role = mathDtlsRole(offer_media.role); + answer_media.ice_ufrag = configure.ice_ufrag; + answer_media.ice_pwd = configure.ice_pwd; + answer_media.fingerprint = configure.fingerprint; + answer_media.ice_lite = configure.ice_lite; +#ifdef ENABLE_SCTP + answer_media.candidate = configure.candidate; +#else + answer_media.port = 0; + WarnL << "answer sdp忽略application mline, 请安装usrsctp后再测试datachannel功能"; +#endif + ret->media.emplace_back(answer_media); + return; + } + for (auto &codec : configure.preferred_codec) { + if (offer_media.ice_lite && configure.ice_lite) { + WarnL << "answer sdp配置为ice_lite模式,与offer sdp中的ice_lite模式冲突"; + continue; + } + const RtcCodecPlan *selected_plan = nullptr; + for (auto &plan : offer_media.plan) { + // 先检查编码格式是否为偏好 [AUTO-TRANSLATED:b7fb32a0] + // First check if the encoding format is preferred + if (check_codec && getCodecId(plan.codec) != codec) { + continue; + } + // 命中偏好的编码格式,然后检查规格 [AUTO-TRANSLATED:a859c839] + // Hit the preferred encoding format, then check the specifications + if (check_profile && !onCheckCodecProfile(plan, codec)) { + continue; + } + // 找到中意的codec [AUTO-TRANSLATED:4b5eebfd] + // Find the desired codec + selected_plan = &plan; + break; + } + if (!selected_plan) { + // offer中该媒体的所有的codec都不支持 [AUTO-TRANSLATED:3b57b86f] + // All codecs for this media in the offer are not supported + continue; + } + RtcMedia answer_media; + answer_media.type = offer_media.type; + answer_media.mid = offer_media.mid; + answer_media.proto = offer_media.proto; + answer_media.port = offer_media.port; + answer_media.addr = offer_media.addr; + answer_media.bandwidth = offer_media.bandwidth; + answer_media.rtcp_addr = offer_media.rtcp_addr; + answer_media.rtcp_mux = offer_media.rtcp_mux && configure.rtcp_mux; + answer_media.rtcp_rsize = offer_media.rtcp_rsize && configure.rtcp_rsize; + answer_media.ice_trickle = offer_media.ice_trickle && configure.ice_trickle; + answer_media.ice_renomination = offer_media.ice_renomination && configure.ice_renomination; + answer_media.ice_ufrag = configure.ice_ufrag; + answer_media.ice_pwd = configure.ice_pwd; + answer_media.fingerprint = configure.fingerprint; + answer_media.ice_lite = configure.ice_lite; + answer_media.candidate = configure.candidate; + // copy simulicast setting + answer_media.rtp_rids = offer_media.rtp_rids; + answer_media.rtp_ssrc_sim = offer_media.rtp_ssrc_sim; + + answer_media.role = mathDtlsRole(offer_media.role); + + // 如果codec匹配失败,那么禁用该track [AUTO-TRANSLATED:037de9a8] + // If the codec matching fails, then disable the track + answer_media.direction = check_codec ? matchDirection(offer_media.direction, configure.direction) : RtpDirection::inactive; + if (answer_media.direction == RtpDirection::invalid) { + continue; + } + if (answer_media.direction == RtpDirection::sendrecv) { + // 如果是收发双向,那么我们拷贝offer sdp的ssrc,确保ssrc一致 [AUTO-TRANSLATED:d4a621f2] + // If it is bidirectional, then we copy the offer sdp ssrc to ensure ssrc consistency + answer_media.rtp_rtx_ssrc = offer_media.rtp_rtx_ssrc; + } + + // 添加媒体plan [AUTO-TRANSLATED:3f730050] + // Add media plan + answer_media.plan.emplace_back(*selected_plan); + onSelectPlan(answer_media.plan.back(), codec); + + set pt_selected = { selected_plan->pt }; + + // 添加rtx,red,ulpfec plan [AUTO-TRANSLATED:1abff0c1] + // Add rtx, red, ulpfec plan + if (configure.support_red || configure.support_rtx || configure.support_ulpfec) { + for (auto &plan : offer_media.plan) { + if (!strcasecmp(plan.codec.data(), "rtx")) { + if (configure.support_rtx && atoi(plan.getFmtp("apt").data()) == selected_plan->pt) { + answer_media.plan.emplace_back(plan); + pt_selected.emplace(plan.pt); + } + continue; + } + if (!strcasecmp(plan.codec.data(), "red")) { + if (configure.support_red) { + answer_media.plan.emplace_back(plan); + pt_selected.emplace(plan.pt); + } + continue; + } + if (!strcasecmp(plan.codec.data(), "ulpfec")) { + if (configure.support_ulpfec) { + answer_media.plan.emplace_back(plan); + pt_selected.emplace(plan.pt); + } + continue; + } + } + } + + // 对方和我方都支持的扩展,那么我们才支持 [AUTO-TRANSLATED:a6cd98b2] + // We only support extensions that are supported by both the other party and us + for (auto &ext : offer_media.extmap) { + auto it = configure.extmap.find(RtpExt::getExtType(ext.ext)); + if (it != configure.extmap.end()) { + auto new_dir = matchDirection(ext.direction, it->second); + switch (new_dir) { + case RtpDirection::invalid: + case RtpDirection::inactive: continue; + default: break; + } + answer_media.extmap.emplace_back(ext); + answer_media.extmap.back().direction = new_dir; + } + } + + auto &rtcp_fb_ref = answer_media.plan[0].rtcp_fb; + rtcp_fb_ref.clear(); + // 对方和我方都支持的rtcpfb,那么我们才支持 [AUTO-TRANSLATED:f10450bb] + // We only support rtcpfb that is supported by both the other party and us + for (auto &fp : selected_plan->rtcp_fb) { + if (configure.rtcp_fb.find(fp) != configure.rtcp_fb.end()) { + // 对方该rtcp被我们支持 [AUTO-TRANSLATED:3b16e666] + // The other party's rtcp is supported by us + rtcp_fb_ref.emplace(fp); + } + } + +#if 0 + // todo 此处为添加无效的plan,webrtc sdp通过调节plan pt顺序选择匹配的codec,意味着后面的codec其实放在sdp中是无意义的 [AUTO-TRANSLATED:502d0cb2] + // todo This is to add an invalid plan. WebRTC sdp selects the matching codec by adjusting the plan pt order, which means that the subsequent codecs are actually meaningless in the sdp + for (auto &plan : offer_media.plan) { + if (pt_selected.find(plan.pt) == pt_selected.end()) { + answer_media.plan.emplace_back(plan); + } + } +#endif + ret->media.emplace_back(answer_media); + return; + } + + if (check_profile) { + // 如果是由于检查profile导致匹配失败,那么重试一次,且不检查profile [AUTO-TRANSLATED:897fa4ae] + // If the matching fails due to profile check, retry once and do not check profile + check_profile = false; + goto RETRY; + } + + if (check_codec) { + // 如果是由于检查codec导致匹配失败,那么重试一次,且不检查codec [AUTO-TRANSLATED:fbd85968] + // If the matching fails due to codec check, retry once and do not check codec + check_codec = false; + goto RETRY; + } +} + +void RtcConfigure::setPlayRtspInfo(const string &sdp) { + RtcSession session; + video.direction = RtpDirection::inactive; + audio.direction = RtpDirection::inactive; + + session.loadFrom(sdp); + for (auto &m : session.media) { + switch (m.type) { + case TrackVideo: { + video.direction = RtpDirection::sendonly; + _rtsp_video_plan = std::make_shared(m.plan[0]); + video.preferred_codec.clear(); + video.preferred_codec.emplace_back(getCodecId(_rtsp_video_plan->codec)); + break; + } + case TrackAudio: { + audio.direction = RtpDirection::sendonly; + _rtsp_audio_plan = std::make_shared(m.plan[0]); + audio.preferred_codec.clear(); + audio.preferred_codec.emplace_back(getCodecId(_rtsp_audio_plan->codec)); + break; + } + default: break; + } + } +} + +static const string kH264Profile { "profile-level-id" }; +static const string kH265Profile { "profile-id" }; +static const string kMode { "packetization-mode" }; + +bool RtcConfigure::onCheckCodecProfile(const RtcCodecPlan &plan, CodecId codec) const { + if (_rtsp_audio_plan && codec == getCodecId(_rtsp_audio_plan->codec)) { + if (plan.sample_rate != _rtsp_audio_plan->sample_rate || plan.channel != _rtsp_audio_plan->channel) { + // 音频采样率和通道数必须相同 [AUTO-TRANSLATED:6e591932] + // Audio sampling rate and number of channels must be the same + return false; + } + return true; + } + if (_rtsp_video_plan && codec == CodecH264 && getCodecId(_rtsp_video_plan->codec) == CodecH264) { + // h264时,profile-level-id [AUTO-TRANSLATED:94a5f360] + // When h264, profile-level-id + if (strcasecmp(_rtsp_video_plan->fmtp[kH264Profile].data(), const_cast(plan).fmtp[kH264Profile].data())) { + // profile-level-id 不匹配 [AUTO-TRANSLATED:814ec4c4] + // profile-level-id does not match + return false; + } + return true; + } + + if (_rtsp_video_plan && codec == CodecH265 && getCodecId(_rtsp_video_plan->codec) == CodecH265) { + // h265时,profile-id + if (strcasecmp(_rtsp_video_plan->fmtp[kH265Profile].data(), const_cast(plan).fmtp[kH265Profile].data())) { + // profile-id 不匹配 + return false; + } + return true; + } + + return true; +} + +/** + Single NAI Unit Mode = 0. // Single NAI mode (Only nals from 1-23 are allowed) + Non Interleaved Mode = 1,// Non-interleaved Mode: 1-23,24 (STAP-A),28 (FU-A) are allowed + Interleaved Mode = 2, // 25 (STAP-B),26 (MTAP16),27 (MTAP24),28 (EU-A),and 29 (EU-B) are allowed. + Single NAI Unit Mode = 0. // Single NAI mode (Only nals from 1-23 are allowed) + Non Interleaved Mode = 1,// Non-interleaved Mode: 1-23,24 (STAP-A),28 (FU-A) are allowed + Interleaved Mode = 2, // 25 (STAP-B),26 (MTAP16),27 (MTAP24),28 (EU-A),and 29 (EU-B) are allowed. + * + * [AUTO-TRANSLATED:b1526114] + **/ +void RtcConfigure::onSelectPlan(RtcCodecPlan &plan, CodecId codec) const { + if (_rtsp_video_plan && codec == CodecH264 && getCodecId(_rtsp_video_plan->codec) == CodecH264) { + // h264时,设置packetization-mod为一致 [AUTO-TRANSLATED:59a00889] + // When h264, set packetization-mod to be consistent + auto mode = _rtsp_video_plan->fmtp[kMode]; + GET_CONFIG(bool, h264_stap_a, Rtp::kH264StapA); + plan.fmtp[kMode] = mode.empty() ? std::to_string(h264_stap_a) : mode; + } +} + +} // namespace mediakit diff --git a/webrtc/Sdp.h b/webrtc/Sdp.h index b260f852..8139f645 100644 --- a/webrtc/Sdp.h +++ b/webrtc/Sdp.h @@ -1,776 +1,779 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_SDP_H -#define ZLMEDIAKIT_SDP_H - -#include -#include -#include -#include -#include "RtpExt.h" -#include "assert.h" -#include "Extension/Frame.h" -#include "Common/Parser.h" - -namespace mediakit { - -// https://datatracker.ietf.org/doc/rfc4566/?include_text=1 -// https://blog.csdn.net/aggresss/article/details/109850434 -// https://aggresss.blog.csdn.net/article/details/106436703 -// Session description -// v= (protocol version) -// o= (originator and session identifier) -// s= (session name) -// i=* (session information) -// u=* (URI of description) -// e=* (email address) -// p=* (phone number) -// c=* (connection information -- not required if included in -// all media) -// b=* (zero or more bandwidth information lines) -// One or more time descriptions ("t=" and "r=" lines; see below) -// z=* (time zone adjustments) -// k=* (encryption key) -// a=* (zero or more session attribute lines) -// Zero or more media descriptions -// -// Time description -// t= (time the session is active) -// r=* (zero or more repeat times) -// -// Media description, if present -// m= (media name and transport address) -// i=* (media title) -// c=* (connection information -- optional if included at -// session level) -// b=* (zero or more bandwidth information lines) -// k=* (encryption key) -// a=* (zero or more media attribute lines) - -enum class RtpDirection { - invalid = -1, - // 只发送 [AUTO-TRANSLATED:d7e7fdb7] - // Send only - sendonly, - // 只接收 [AUTO-TRANSLATED:f75ca789] - // Receive only - recvonly, - // 同时发送接收 [AUTO-TRANSLATED:7f900ba1] - // Send and receive simultaneously - sendrecv, - // 禁止发送数据 [AUTO-TRANSLATED:6045b47e] - // Prohibit sending data - inactive -}; - -enum class DtlsRole { - invalid = -1, - // 客户端 [AUTO-TRANSLATED:915417a2] - // Client - active, - // 服务端 [AUTO-TRANSLATED:03a80b18] - // Server - passive, - // 既可作做客户端也可以做服务端 [AUTO-TRANSLATED:5ab1162e] - // Can be used as both client and server - actpass, -}; - -enum class SdpType { invalid = -1, offer, answer }; - -DtlsRole getDtlsRole(const std::string &str); -const char *getDtlsRoleString(DtlsRole role); -RtpDirection getRtpDirection(const std::string &str); -const char *getRtpDirectionString(RtpDirection val); - -class SdpItem { -public: - using Ptr = std::shared_ptr; - virtual ~SdpItem() = default; - virtual void parse(const std::string &str) { value = str; } - virtual std::string toString() const { return value; } - virtual const char *getKey() const = 0; - - void reset() { value.clear(); } - -protected: - mutable std::string value; -}; - -template -class SdpString : public SdpItem { -public: - SdpString() = default; - SdpString(std::string val) { value = std::move(val); } - // *=* - const char* getKey() const override { static std::string key(1, KEY); return key.data();} -}; - -class SdpCommon : public SdpItem { -public: - std::string key; - SdpCommon(std::string key) { this->key = std::move(key); } - SdpCommon(std::string key, std::string val) { - this->key = std::move(key); - this->value = std::move(val); - } - - const char *getKey() const override { return key.data(); } -}; - -class SdpTime : public SdpItem { -public: - // 5.9. Timing ("t=") - // t= - uint64_t start { 0 }; - uint64_t stop { 0 }; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "t"; } -}; - -class SdpOrigin : public SdpItem { -public: - // 5.2. Origin ("o=") - // o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5 - // o= - std::string username { "-" }; - std::string session_id; - std::string session_version; - std::string nettype { "IN" }; - std::string addrtype { "IP4" }; - std::string address { "0.0.0.0" }; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "o"; } - bool empty() const { - return username.empty() || session_id.empty() || session_version.empty() - || nettype.empty() || addrtype.empty() || address.empty(); - } -}; - -class SdpConnection : public SdpItem { -public: - // 5.7. Connection Data ("c=") - // c=IN IP4 224.2.17.12/127 - // c= - std::string nettype { "IN" }; - std::string addrtype { "IP4" }; - std::string address { "0.0.0.0" }; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "c"; } - bool empty() const { return address.empty(); } -}; - -class SdpBandwidth : public SdpItem { -public: - // 5.8. Bandwidth ("b=") - // b=: - - // AS、CT [AUTO-TRANSLATED:65298206] - // AS, CT - std::string bwtype { "AS" }; - uint32_t bandwidth { 0 }; - - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "b"; } - bool empty() const { return bandwidth == 0; } -}; - -class SdpMedia : public SdpItem { -public: - // 5.14. Media Descriptions ("m=") - // m= ... - TrackType type; - uint16_t port; - // RTP/AVP:应用场景为视频/音频的 RTP 协议。参考 RFC 3551 [AUTO-TRANSLATED:7a9d7e86] - // RTP/AVP: The application scenario is the RTP protocol for video/audio. Refer to RFC 3551 - // RTP/SAVP:应用场景为视频/音频的 SRTP 协议。参考 RFC 3711 [AUTO-TRANSLATED:7989a619] - // RTP/SAVP: The application scenario is the SRTP protocol for video/audio. Refer to RFC 3711 - // RTP/AVPF: 应用场景为视频/音频的 RTP 协议,支持 RTCP-based Feedback。参考 RFC 4585 [AUTO-TRANSLATED:71241e80] - // RTP/AVPF: The application scenario is the RTP protocol for video/audio, supporting RTCP-based Feedback. Refer to RFC 4585 - // RTP/SAVPF: 应用场景为视频/音频的 SRTP 协议,支持 RTCP-based Feedback。参考 RFC 5124 [AUTO-TRANSLATED:69015267] - // RTP/SAVPF: The application scenario is the SRTP protocol for video/audio, supporting RTCP-based Feedback. Refer to RFC 5124 - std::string proto; - std::vector fmts; - - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "m"; } -}; - -class SdpAttr : public SdpItem { -public: - using Ptr = std::shared_ptr; - // 5.13. Attributes ("a=") - // a= - // a=: - SdpItem::Ptr detail; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "a"; } -}; - -class SdpAttrGroup : public SdpItem { -public: - // a=group:BUNDLE line with all the 'mid' identifiers part of the - // BUNDLE group is included at the session-level. - // a=group:LS session level attribute MUST be included wth the 'mid' - // identifiers that are part of the same lip sync group. - std::string type { "BUNDLE" }; - std::vector mids; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "group"; } -}; - -class SdpAttrMsidSemantic : public SdpItem { -public: - // https://tools.ietf.org/html/draft-alvestrand-rtcweb-msid-02#section-3 - // 3. The Msid-Semantic Attribute - // - // In order to fully reproduce the semantics of the SDP and SSRC - // grouping frameworks, a session-level attribute is defined for - // signalling the semantics associated with an msid grouping. - // - // This OPTIONAL attribute gives the message ID and its group semantic. - // a=msid-semantic: examplefoo LS - // - // - // The ABNF of msid-semantic is: - // - // msid-semantic-attr = "msid-semantic:" " " msid token - // token = - // - // The semantic field may hold values from the IANA registries - // "Semantics for the "ssrc-group" SDP Attribute" and "Semantics for the - // "group" SDP Attribute". - // a=msid-semantic: WMS 616cfbb1-33a3-4d8c-8275-a199d6005549 - std::string msid { "WMS" }; - std::string token; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "msid-semantic"; } - bool empty() const { return msid.empty(); } -}; - -class SdpAttrRtcp : public SdpItem { -public: - // a=rtcp:9 IN IP4 0.0.0.0 - uint16_t port { 0 }; - std::string nettype { "IN" }; - std::string addrtype { "IP4" }; - std::string address { "0.0.0.0" }; - void parse(const std::string &str) override; - ; - std::string toString() const override; - const char *getKey() const override { return "rtcp"; } - bool empty() const { return address.empty() || !port; } -}; - -class SdpAttrIceUfrag : public SdpItem { -public: - SdpAttrIceUfrag() = default; - SdpAttrIceUfrag(std::string str) { value = std::move(str); } - // a=ice-ufrag:sXJ3 - const char *getKey() const override { return "ice-ufrag"; } -}; - -class SdpAttrIcePwd : public SdpItem { -public: - SdpAttrIcePwd() = default; - SdpAttrIcePwd(std::string str) { value = std::move(str); } - // a=ice-pwd:yEclOTrLg1gEubBFefOqtmyV - const char *getKey() const override { return "ice-pwd"; } -}; - -class SdpAttrIceOption : public SdpItem { -public: - // a=ice-options:trickle - bool trickle { false }; - bool renomination { false }; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "ice-options"; } -}; - -class SdpAttrFingerprint : public SdpItem { -public: - // a=fingerprint:sha-256 22:14:B5:AF:66:12:C7:C7:8D:EF:4B:DE:40:25:ED:5D:8F:17:54:DD:88:33:C0:13:2E:FD:1A:FA:7E:7A:1B:79 - std::string algorithm; - std::string hash; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "fingerprint"; } - bool empty() const { return algorithm.empty() || hash.empty(); } -}; - -class SdpAttrSetup : public SdpItem { -public: - // a=setup:actpass - SdpAttrSetup() = default; - SdpAttrSetup(DtlsRole r) { role = r; } - DtlsRole role { DtlsRole::actpass }; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "setup"; } -}; - -class SdpAttrMid : public SdpItem { -public: - SdpAttrMid() = default; - SdpAttrMid(std::string val) { value = std::move(val); } - // a=mid:audio - const char *getKey() const override { return "mid"; } -}; - -class SdpAttrExtmap : public SdpItem { -public: - // https://aggresss.blog.csdn.net/article/details/106436703 - // a=extmap:1[/sendonly] urn:ietf:params:rtp-hdrext:ssrc-audio-level - uint8_t id; - RtpDirection direction { RtpDirection::invalid }; - std::string ext; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "extmap"; } -}; - -class SdpAttrRtpMap : public SdpItem { -public: - // a=rtpmap:111 opus/48000/2 - uint8_t pt; - std::string codec; - uint32_t sample_rate; - uint32_t channel { 0 }; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "rtpmap"; } -}; - -class SdpAttrRtcpFb : public SdpItem { -public: - // a=rtcp-fb:98 nack pli - // a=rtcp-fb:120 nack 支持 nack 重传,nack (Negative-Acknowledgment) 。 [AUTO-TRANSLATED:08d5c4e2] - // a=rtcp-fb:120 nack supports nack retransmission, nack (Negative-Acknowledgment). - // a=rtcp-fb:120 nack pli 支持 nack 关键帧重传,PLI (Picture Loss Indication) 。 [AUTO-TRANSLATED:c331c1dd] - // a=rtcp-fb:120 nack pli supports nack keyframe retransmission, PLI (Picture Loss Indication). - // a=rtcp-fb:120 ccm fir 支持编码层关键帧请求,CCM (Codec Control Message),FIR (Full Intra Request ),通常与 nack pli 有同样的效果,但是 nack pli [AUTO-TRANSLATED:7090fdc9] - // a=rtcp-fb:120 ccm fir supports keyframe requests for the coding layer, CCM (Codec Control Message), FIR (Full Intra Request), which usually has the same effect as nack pli, but nack pli - // 是用于重传时的关键帧请求。 a=rtcp-fb:120 goog-remb 支持 REMB (Receiver Estimated Maximum Bitrate) 。 a=rtcp-fb:120 transport-cc 支持 TCC (Transport [AUTO-TRANSLATED:ffac8e91] - // is used for keyframe requests during retransmission. a=rtcp-fb:120 goog-remb supports REMB (Receiver Estimated Maximum Bitrate). a=rtcp-fb:120 transport-cc supports TCC (Transport - // Congest Control) 。 [AUTO-TRANSLATED:dcf53e31] - // Congest Control). - uint8_t pt; - std::string rtcp_type; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "rtcp-fb"; } -}; - -class SdpAttrFmtp : public SdpItem { -public: - // fmtp:96 level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f - uint8_t pt; - std::map fmtp; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "fmtp"; } -}; - -class SdpAttrSSRC : public SdpItem { -public: - // a=ssrc:3245185839 cname:Cx4i/VTR51etgjT7 - // a=ssrc:3245185839 msid:cb373bff-0fea-4edb-bc39-e49bb8e8e3b9 0cf7e597-36a2-4480-9796-69bf0955eef5 - // a=ssrc:3245185839 mslabel:cb373bff-0fea-4edb-bc39-e49bb8e8e3b9 - // a=ssrc:3245185839 label:0cf7e597-36a2-4480-9796-69bf0955eef5 - // a=ssrc: - // a=ssrc: : - // cname 是必须的,msid/mslabel/label 这三个属性都是 WebRTC 自创的,或者说 Google 自创的,可以参考 https://tools.ietf.org/html/draft-ietf-mmusic-msid-17, [AUTO-TRANSLATED:d8cb1baf] - // cname is required, msid/mslabel/label these three attributes are all created by WebRTC, or Google created, you can refer to https://tools.ietf.org/html/draft-ietf-mmusic-msid-17, - // 理解它们三者的关系需要先了解三个概念:RTP stream / MediaStreamTrack / MediaStream : [AUTO-TRANSLATED:7d385cf5] - // understanding the relationship between the three requires understanding three concepts: RTP stream / MediaStreamTrack / MediaStream: - // 一个 a=ssrc 代表一个 RTP stream ; [AUTO-TRANSLATED:ee1ecc6f] - // One a=ssrc represents one RTP stream; - // 一个 MediaStreamTrack 通常包含一个或多个 RTP stream,例如一个视频 MediaStreamTrack 中通常包含两个 RTP stream,一个用于常规传输,一个用于 nack 重传; [AUTO-TRANSLATED:e8ddf0fd] - // A MediaStreamTrack usually contains one or more RTP streams, for example, a video MediaStreamTrack usually contains two RTP streams, one for regular transmission and one for nack retransmission; - // 一个 MediaStream 通常包含一个或多个 MediaStreamTrack ,例如 simulcast 场景下,一个 MediaStream 通常会包含三个不同编码质量的 MediaStreamTrack ; [AUTO-TRANSLATED:31318d43] - // A MediaStream usually contains one or more MediaStreamTrack, for example, in a simulcast scenario, a MediaStream usually contains three MediaStreamTrack of different encoding quality; - // 这种标记方式并不被 Firefox 认可,在 Firefox 生成的 SDP 中一个 a=ssrc 通常只有一行,例如: [AUTO-TRANSLATED:8c2c424c] - // This marking method is not recognized by Firefox, in the SDP generated by Firefox, one a=ssrc usually has only one line, for example: - // a=ssrc:3245185839 cname:Cx4i/VTR51etgjT7 - - uint32_t ssrc; - std::string attribute; - std::string attribute_value; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "ssrc"; } -}; - -class SdpAttrSSRCGroup : public SdpItem { -public: - // a=ssrc-group 定义参考 RFC 5576(https://tools.ietf.org/html/rfc5576) ,用于描述多个 ssrc 之间的关联,常见的有两种: [AUTO-TRANSLATED:a87cbcc6] - // a=ssrc-group definition refers to RFC 5576(https://tools.ietf.org/html/rfc5576), used to describe the association between multiple ssrcs, there are two common types: - // a=ssrc-group:FID 2430709021 3715850271 - // FID (Flow Identification) 最初用在 FEC 的关联中,WebRTC 中通常用于关联一组常规 RTP stream 和 重传 RTP stream 。 [AUTO-TRANSLATED:f2c0fcbb] - // FID (Flow Identification) was originally used in FEC association, and in WebRTC it is usually used to associate a group of regular RTP streams and retransmission RTP streams. - // a=ssrc-group:SIM 360918977 360918978 360918980 - // 在 Chrome 独有的 SDP munging 风格的 simulcast 中使用,将三组编码质量由低到高的 MediaStreamTrack 关联在一起。 [AUTO-TRANSLATED:61bf7596] - // Used in Chrome's unique SDP munging style simulcast, associating three groups of MediaStreamTrack from low to high encoding quality. - std::string type { "FID" }; - std::vector ssrcs; - - bool isFID() const { return type == "FID"; } - bool isSIM() const { return type == "SIM"; } - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "ssrc-group"; } -}; - -class SdpAttrSctpMap : public SdpItem { -public: - // https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-05 - // a=sctpmap:5000 webrtc-datachannel 1024 - // a=sctpmap: sctpmap-number media-subtypes [streams] - uint16_t port = 0; - std::string subtypes; - uint32_t streams = 0; - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "sctpmap"; } - bool empty() const { return port == 0 && subtypes.empty() && streams == 0; } -}; - -class SdpAttrCandidate : public SdpItem { -public: - using Ptr = std::shared_ptr; - // https://tools.ietf.org/html/rfc5245 - // 15.1. "candidate" Attribute - // a=candidate:4 1 udp 2 192.168.1.7 58107 typ host - // a=candidate:
typ - std::string foundation; - // 传输媒体的类型,1代表RTP;2代表 RTCP。 [AUTO-TRANSLATED:9ec924a6] - // The type of media to be transmitted, 1 represents RTP; 2 represents RTCP. - uint32_t component; - std::string transport { "udp" }; - uint32_t priority; - std::string address; - uint16_t port; - std::string type; - std::vector> arr; - - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "candidate"; } -}; - -class SdpAttrMsid : public SdpItem { -public: - const char *getKey() const override { return "msid"; } -}; - -class SdpAttrExtmapAllowMixed : public SdpItem { -public: - const char *getKey() const override { return "extmap-allow-mixed"; } -}; - -class SdpAttrSimulcast : public SdpItem { -public: - // https://www.meetecho.com/blog/simulcast-janus-ssrc/ - // https://tools.ietf.org/html/draft-ietf-mmusic-sdp-simulcast-14 - const char *getKey() const override { return "simulcast"; } - void parse(const std::string &str) override; - std::string toString() const override; - bool empty() const { return rids.empty(); } - std::string direction; - std::vector rids; -}; - -class SdpAttrRid : public SdpItem { -public: - void parse(const std::string &str) override; - std::string toString() const override; - const char *getKey() const override { return "rid"; } - std::string direction; - std::string rid; -}; - -class RtcSdpBase { -public: - void addItem(SdpItem::Ptr item) { items.push_back(std::move(item)); } - void addAttr(SdpItem::Ptr attr) { - auto item = std::make_shared(); - item->detail = std::move(attr); - items.push_back(std::move(item)); - } - - virtual ~RtcSdpBase() = default; - virtual std::string toString() const; - void toRtsp(); - - RtpDirection getDirection() const; - - template - cls getItemClass(char key, const char *attr_key = nullptr) const { - auto item = std::dynamic_pointer_cast(getItem(key, attr_key)); - if (!item) { - return cls(); - } - return *item; - } - - std::string getStringItem(char key, const char *attr_key = nullptr) const { - auto item = getItem(key, attr_key); - if (!item) { - return ""; - } - return item->toString(); - } - - SdpItem::Ptr getItem(char key, const char *attr_key = nullptr) const; - - template - std::vector getAllItem(char key_c, const char *attr_key = nullptr) const { - std::vector ret; - std::string key(1, key_c); - for (auto item : items) { - if (strcasecmp(item->getKey(), key.data()) == 0) { - if (!attr_key) { - auto c = std::dynamic_pointer_cast(item); - if (c) { - ret.emplace_back(*c); - } - } else { - auto attr = std::dynamic_pointer_cast(item); - if (attr && !strcasecmp(attr->detail->getKey(), attr_key)) { - auto c = std::dynamic_pointer_cast(attr->detail); - if (c) { - ret.emplace_back(*c); - } - } - } - } - } - return ret; - } - -private: - std::vector items; -}; - -class RtcSessionSdp : public RtcSdpBase { -public: - using Ptr = std::shared_ptr; - int getVersion() const; - SdpOrigin getOrigin() const; - std::string getSessionName() const; - std::string getSessionInfo() const; - SdpTime getSessionTime() const; - SdpConnection getConnection() const; - SdpBandwidth getBandwidth() const; - - std::string getUri() const; - std::string getEmail() const; - std::string getPhone() const; - std::string getTimeZone() const; - std::string getEncryptKey() const; - std::string getRepeatTimes() const; - - std::vector medias; - void parse(const std::string &str); - std::string toString() const override; -}; - -////////////////////////////////////////////////////////////////// - -// ssrc相关信息 [AUTO-TRANSLATED:954c641d] -// ssrc related information -class RtcSSRC { -public: - uint32_t ssrc { 0 }; - uint32_t rtx_ssrc { 0 }; - std::string cname; - std::string msid; - std::string mslabel; - std::string label; - - bool empty() const { return ssrc == 0 && cname.empty(); } -}; - -// rtc传输编码方案 [AUTO-TRANSLATED:8b911508] -// rtc transmission encoding scheme -class RtcCodecPlan { -public: - using Ptr = std::shared_ptr; - uint8_t pt; - std::string codec; - uint32_t sample_rate; - // 音频时有效 [AUTO-TRANSLATED:5b230fc8] - // Valid for audio - uint32_t channel = 0; - // rtcp反馈 [AUTO-TRANSLATED:580378bd] - // RTCP feedback - std::set rtcp_fb; - std::map fmtp; - - std::string getFmtp(const char *key) const; -}; - -// rtc 媒体描述 [AUTO-TRANSLATED:b1711a11] -// RTC media description -class RtcMedia { -public: - TrackType type { TrackType::TrackInvalid }; - std::string mid; - uint16_t port { 0 }; - SdpConnection addr; - SdpBandwidth bandwidth; - std::string proto; - RtpDirection direction { RtpDirection::invalid }; - std::vector plan; - - //////// rtp //////// - std::vector rtp_rtx_ssrc; - - //////// simulcast //////// - std::vector rtp_ssrc_sim; - std::vector rtp_rids; - - //////// rtcp //////// - bool rtcp_mux { false }; - bool rtcp_rsize { false }; - SdpAttrRtcp rtcp_addr; - - //////// ice //////// - bool ice_trickle { false }; - bool ice_lite { false }; - bool ice_renomination { false }; - std::string ice_ufrag; - std::string ice_pwd; - std::vector candidate; - - //////// dtls //////// - DtlsRole role { DtlsRole::invalid }; - SdpAttrFingerprint fingerprint; - - //////// extmap //////// - std::vector extmap; - - //////// sctp //////////// - SdpAttrSctpMap sctpmap; - uint32_t sctp_port { 0 }; - - void checkValid() const; - const RtcCodecPlan *getPlan(uint8_t pt) const; - const RtcCodecPlan *getPlan(const char *codec) const; - const RtcCodecPlan *getRelatedRtxPlan(uint8_t pt) const; - uint32_t getRtpSSRC() const; - uint32_t getRtxSSRC() const; - bool supportSimulcast() const; -}; - -class RtcSession { -public: - using Ptr = std::shared_ptr; - - uint32_t version; - SdpOrigin origin; - std::string session_name; - std::string session_info; - SdpTime time; - SdpConnection connection; - SdpAttrMsidSemantic msid_semantic; - std::vector media; - SdpAttrGroup group; - - void loadFrom(const std::string &sdp); - void checkValid() const; - std::string toString() const; - std::string toRtspSdp() const; - const RtcMedia *getMedia(TrackType type) const; - bool supportRtcpFb(const std::string &name, TrackType type = TrackType::TrackVideo) const; - bool supportSimulcast() const; - bool isOnlyDatachannel() const; - -private: - RtcSessionSdp::Ptr toRtcSessionSdp() const; -}; - -class RtcConfigure { -public: - using Ptr = std::shared_ptr; - class RtcTrackConfigure { - public: - bool rtcp_mux; - bool rtcp_rsize; - bool group_bundle; - bool support_rtx; - bool support_red; - bool support_ulpfec; - bool ice_lite; - bool ice_trickle; - bool ice_renomination; - std::string ice_ufrag; - std::string ice_pwd; - - RtpDirection direction { RtpDirection::invalid }; - SdpAttrFingerprint fingerprint; - - std::set rtcp_fb; - std::map extmap; - std::vector preferred_codec; - std::vector candidate; - - void setDefaultSetting(TrackType type); - void enableTWCC(bool enable = true); - void enableREMB(bool enable = true); - }; - - RtcTrackConfigure video; - RtcTrackConfigure audio; - RtcTrackConfigure application; - - void setDefaultSetting(std::string ice_ufrag, std::string ice_pwd, RtpDirection direction, const SdpAttrFingerprint &fingerprint); - void addCandidate(const SdpAttrCandidate &candidate, TrackType type = TrackInvalid); - - std::shared_ptr createAnswer(const RtcSession &offer) const; - - void setPlayRtspInfo(const std::string &sdp); - - void enableTWCC(bool enable = true, TrackType type = TrackInvalid); - void enableREMB(bool enable = true, TrackType type = TrackInvalid); - -private: - void matchMedia(const std::shared_ptr &ret, const RtcMedia &media) const; - bool onCheckCodecProfile(const RtcCodecPlan &plan, CodecId codec) const; - void onSelectPlan(RtcCodecPlan &plan, CodecId codec) const; - -private: - RtcCodecPlan::Ptr _rtsp_video_plan; - RtcCodecPlan::Ptr _rtsp_audio_plan; -}; - -class SdpConst { -public: - static std::string const kTWCCRtcpFb; - static std::string const kRembRtcpFb; - -private: - SdpConst() = delete; - ~SdpConst() = delete; -}; - -} // namespace mediakit - -#endif // ZLMEDIAKIT_SDP_H +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_SDP_H +#define ZLMEDIAKIT_SDP_H + +#include +#include +#include +#include +#include "RtpExt.h" +#include "RtpMap.h" +#include "Extension/Frame.h" +#include "Common/Parser.h" + +namespace mediakit { + +// https://datatracker.ietf.org/doc/rfc4566/?include_text=1 +// https://blog.csdn.net/aggresss/article/details/109850434 +// https://aggresss.blog.csdn.net/article/details/106436703 +// Session description +// v= (protocol version) +// o= (originator and session identifier) +// s= (session name) +// i=* (session information) +// u=* (URI of description) +// e=* (email address) +// p=* (phone number) +// c=* (connection information -- not required if included in +// all media) +// b=* (zero or more bandwidth information lines) +// One or more time descriptions ("t=" and "r=" lines; see below) +// z=* (time zone adjustments) +// k=* (encryption key) +// a=* (zero or more session attribute lines) +// Zero or more media descriptions +// +// Time description +// t= (time the session is active) +// r=* (zero or more repeat times) +// +// Media description, if present +// m= (media name and transport address) +// i=* (media title) +// c=* (connection information -- optional if included at +// session level) +// b=* (zero or more bandwidth information lines) +// k=* (encryption key) +// a=* (zero or more media attribute lines) + +enum class RtpDirection : int8_t { + invalid = -1, + // 只发送 [AUTO-TRANSLATED:d7e7fdb7] + // Send only + sendonly = 1 << 0, + // 只接收 [AUTO-TRANSLATED:f75ca789] + // Receive only + recvonly = 1 << 1, + // 同时发送接收 [AUTO-TRANSLATED:7f900ba1] + // Send and receive simultaneously + sendrecv = sendonly | recvonly, + // 禁止发送数据 [AUTO-TRANSLATED:6045b47e] + // Prohibit sending data + inactive = 0 +}; + +enum class DtlsRole : int8_t { + invalid = -1, + // 客户端 [AUTO-TRANSLATED:915417a2] + // Client + active = 1 << 0, + // 服务端 [AUTO-TRANSLATED:03a80b18] + // Server + passive = 1 << 1, + // 既可作做客户端也可以做服务端 [AUTO-TRANSLATED:5ab1162e] + // Can be used as both client and server + actpass = active | passive, +}; + +enum class SdpType : int8_t { invalid = -1, offer, answer }; + +DtlsRole getDtlsRole(const std::string &str); +const char *getDtlsRoleString(DtlsRole role); +RtpDirection getRtpDirection(const std::string &str); +const char *getRtpDirectionString(RtpDirection val); + +class SdpItem { +public: + using Ptr = std::shared_ptr; + virtual ~SdpItem() = default; + virtual void parse(const std::string &str) { value = str; } + virtual std::string toString() const { return value; } + virtual const char *getKey() const = 0; + + void reset() { value.clear(); } + +protected: + mutable std::string value; +}; + +template +class SdpString : public SdpItem { +public: + SdpString() = default; + SdpString(std::string val) { value = std::move(val); } + // *=* + const char* getKey() const override { static std::string key(1, KEY); return key.data();} +}; + +class SdpCommon : public SdpItem { +public: + std::string key; + SdpCommon(std::string key) { this->key = std::move(key); } + SdpCommon(std::string key, std::string val) { + this->key = std::move(key); + this->value = std::move(val); + } + + const char *getKey() const override { return key.data(); } +}; + +class SdpTime : public SdpItem { +public: + // 5.9. Timing ("t=") + // t= + uint64_t start { 0 }; + uint64_t stop { 0 }; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "t"; } +}; + +class SdpOrigin : public SdpItem { +public: + // 5.2. Origin ("o=") + // o=jdoe 2890844526 2890842807 IN IP4 10.47.16.5 + // o= + std::string username { "-" }; + std::string session_id; + std::string session_version; + std::string nettype { "IN" }; + std::string addrtype { "IP4" }; + std::string address { "0.0.0.0" }; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "o"; } + bool empty() const { + return username.empty() || session_id.empty() || session_version.empty() + || nettype.empty() || addrtype.empty() || address.empty(); + } +}; + +class SdpConnection : public SdpItem { +public: + // 5.7. Connection Data ("c=") + // c=IN IP4 224.2.17.12/127 + // c= + std::string nettype { "IN" }; + std::string addrtype { "IP4" }; + std::string address { "0.0.0.0" }; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "c"; } + bool empty() const { return address.empty(); } +}; + +class SdpBandwidth : public SdpItem { +public: + // 5.8. Bandwidth ("b=") + // b=: + + // AS、CT [AUTO-TRANSLATED:65298206] + // AS, CT + std::string bwtype { "AS" }; + uint32_t bandwidth { 0 }; + + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "b"; } + bool empty() const { return bandwidth == 0; } +}; + +class SdpMedia : public SdpItem { +public: + // 5.14. Media Descriptions ("m=") + // m= ... + TrackType type; + uint16_t port; + // RTP/AVP:应用场景为视频/音频的 RTP 协议。参考 RFC 3551 [AUTO-TRANSLATED:7a9d7e86] + // RTP/AVP: The application scenario is the RTP protocol for video/audio. Refer to RFC 3551 + // RTP/SAVP:应用场景为视频/音频的 SRTP 协议。参考 RFC 3711 [AUTO-TRANSLATED:7989a619] + // RTP/SAVP: The application scenario is the SRTP protocol for video/audio. Refer to RFC 3711 + // RTP/AVPF: 应用场景为视频/音频的 RTP 协议,支持 RTCP-based Feedback。参考 RFC 4585 [AUTO-TRANSLATED:71241e80] + // RTP/AVPF: The application scenario is the RTP protocol for video/audio, supporting RTCP-based Feedback. Refer to RFC 4585 + // RTP/SAVPF: 应用场景为视频/音频的 SRTP 协议,支持 RTCP-based Feedback。参考 RFC 5124 [AUTO-TRANSLATED:69015267] + // RTP/SAVPF: The application scenario is the SRTP protocol for video/audio, supporting RTCP-based Feedback. Refer to RFC 5124 + std::string proto; + std::vector fmts; + + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "m"; } +}; + +class SdpAttr : public SdpItem { +public: + using Ptr = std::shared_ptr; + // 5.13. Attributes ("a=") + // a= + // a=: + SdpItem::Ptr detail; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "a"; } +}; + +class SdpAttrGroup : public SdpItem { +public: + // a=group:BUNDLE line with all the 'mid' identifiers part of the + // BUNDLE group is included at the session-level. + // a=group:LS session level attribute MUST be included wth the 'mid' + // identifiers that are part of the same lip sync group. + std::string type { "BUNDLE" }; + std::vector mids; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "group"; } +}; + +class SdpAttrMsidSemantic : public SdpItem { +public: + // https://tools.ietf.org/html/draft-alvestrand-rtcweb-msid-02#section-3 + // 3. The Msid-Semantic Attribute + // + // In order to fully reproduce the semantics of the SDP and SSRC + // grouping frameworks, a session-level attribute is defined for + // signalling the semantics associated with an msid grouping. + // + // This OPTIONAL attribute gives the message ID and its group semantic. + // a=msid-semantic: examplefoo LS + // + // + // The ABNF of msid-semantic is: + // + // msid-semantic-attr = "msid-semantic:" " " msid token + // token = + // + // The semantic field may hold values from the IANA registries + // "Semantics for the "ssrc-group" SDP Attribute" and "Semantics for the + // "group" SDP Attribute". + // a=msid-semantic: WMS 616cfbb1-33a3-4d8c-8275-a199d6005549 + std::string msid { "WMS" }; + std::string token; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "msid-semantic"; } + bool empty() const { return msid.empty(); } +}; + +class SdpAttrRtcp : public SdpItem { +public: + // a=rtcp:9 IN IP4 0.0.0.0 + uint16_t port { 0 }; + std::string nettype { "IN" }; + std::string addrtype { "IP4" }; + std::string address { "0.0.0.0" }; + void parse(const std::string &str) override; + ; + std::string toString() const override; + const char *getKey() const override { return "rtcp"; } + bool empty() const { return address.empty() || !port; } +}; + +class SdpAttrIceUfrag : public SdpItem { +public: + SdpAttrIceUfrag() = default; + SdpAttrIceUfrag(std::string str) { value = std::move(str); } + // a=ice-ufrag:sXJ3 + const char *getKey() const override { return "ice-ufrag"; } +}; + +class SdpAttrIcePwd : public SdpItem { +public: + SdpAttrIcePwd() = default; + SdpAttrIcePwd(std::string str) { value = std::move(str); } + // a=ice-pwd:yEclOTrLg1gEubBFefOqtmyV + const char *getKey() const override { return "ice-pwd"; } +}; + +class SdpAttrIceOption : public SdpItem { +public: + // a=ice-options:trickle + bool trickle { false }; + bool renomination { false }; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "ice-options"; } +}; + +class SdpAttrFingerprint : public SdpItem { +public: + // a=fingerprint:sha-256 22:14:B5:AF:66:12:C7:C7:8D:EF:4B:DE:40:25:ED:5D:8F:17:54:DD:88:33:C0:13:2E:FD:1A:FA:7E:7A:1B:79 + std::string algorithm; + std::string hash; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "fingerprint"; } + bool empty() const { return algorithm.empty() || hash.empty(); } +}; + +class SdpAttrSetup : public SdpItem { +public: + // a=setup:actpass + SdpAttrSetup() = default; + SdpAttrSetup(DtlsRole r) { role = r; } + DtlsRole role { DtlsRole::actpass }; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "setup"; } +}; + +class SdpAttrMid : public SdpItem { +public: + SdpAttrMid() = default; + SdpAttrMid(std::string val) { value = std::move(val); } + // a=mid:audio + const char *getKey() const override { return "mid"; } +}; + +class SdpAttrExtmap : public SdpItem { +public: + // https://aggresss.blog.csdn.net/article/details/106436703 + // a=extmap:1[/sendonly] urn:ietf:params:rtp-hdrext:ssrc-audio-level + uint8_t id; + RtpDirection direction { RtpDirection::invalid }; + std::string ext; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "extmap"; } +}; + +class SdpAttrRtpMap : public SdpItem { +public: + // a=rtpmap:111 opus/48000/2 + uint8_t pt; + std::string codec; + uint32_t sample_rate; + uint32_t channel { 0 }; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "rtpmap"; } +}; + +class SdpAttrRtcpFb : public SdpItem { +public: + // a=rtcp-fb:98 nack pli + // a=rtcp-fb:120 nack 支持 nack 重传,nack (Negative-Acknowledgment) 。 [AUTO-TRANSLATED:08d5c4e2] + // a=rtcp-fb:120 nack supports nack retransmission, nack (Negative-Acknowledgment). + // a=rtcp-fb:120 nack pli 支持 nack 关键帧重传,PLI (Picture Loss Indication) 。 [AUTO-TRANSLATED:c331c1dd] + // a=rtcp-fb:120 nack pli supports nack keyframe retransmission, PLI (Picture Loss Indication). + // a=rtcp-fb:120 ccm fir 支持编码层关键帧请求,CCM (Codec Control Message),FIR (Full Intra Request ),通常与 nack pli 有同样的效果,但是 nack pli [AUTO-TRANSLATED:7090fdc9] + // a=rtcp-fb:120 ccm fir supports keyframe requests for the coding layer, CCM (Codec Control Message), FIR (Full Intra Request), which usually has the same effect as nack pli, but nack pli + // 是用于重传时的关键帧请求。 a=rtcp-fb:120 goog-remb 支持 REMB (Receiver Estimated Maximum Bitrate) 。 a=rtcp-fb:120 transport-cc 支持 TCC (Transport [AUTO-TRANSLATED:ffac8e91] + // is used for keyframe requests during retransmission. a=rtcp-fb:120 goog-remb supports REMB (Receiver Estimated Maximum Bitrate). a=rtcp-fb:120 transport-cc supports TCC (Transport + // Congest Control) 。 [AUTO-TRANSLATED:dcf53e31] + // Congest Control). + uint8_t pt; + std::string rtcp_type; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "rtcp-fb"; } +}; + +class SdpAttrFmtp : public SdpItem { +public: + // fmtp:96 level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42e01f + uint8_t pt; + std::map fmtp; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "fmtp"; } +}; + +class SdpAttrSSRC : public SdpItem { +public: + // a=ssrc:3245185839 cname:Cx4i/VTR51etgjT7 + // a=ssrc:3245185839 msid:cb373bff-0fea-4edb-bc39-e49bb8e8e3b9 0cf7e597-36a2-4480-9796-69bf0955eef5 + // a=ssrc:3245185839 mslabel:cb373bff-0fea-4edb-bc39-e49bb8e8e3b9 + // a=ssrc:3245185839 label:0cf7e597-36a2-4480-9796-69bf0955eef5 + // a=ssrc: + // a=ssrc: : + // cname 是必须的,msid/mslabel/label 这三个属性都是 WebRTC 自创的,或者说 Google 自创的,可以参考 https://tools.ietf.org/html/draft-ietf-mmusic-msid-17, [AUTO-TRANSLATED:d8cb1baf] + // cname is required, msid/mslabel/label these three attributes are all created by WebRTC, or Google created, you can refer to https://tools.ietf.org/html/draft-ietf-mmusic-msid-17, + // 理解它们三者的关系需要先了解三个概念:RTP stream / MediaStreamTrack / MediaStream : [AUTO-TRANSLATED:7d385cf5] + // understanding the relationship between the three requires understanding three concepts: RTP stream / MediaStreamTrack / MediaStream: + // 一个 a=ssrc 代表一个 RTP stream ; [AUTO-TRANSLATED:ee1ecc6f] + // One a=ssrc represents one RTP stream; + // 一个 MediaStreamTrack 通常包含一个或多个 RTP stream,例如一个视频 MediaStreamTrack 中通常包含两个 RTP stream,一个用于常规传输,一个用于 nack 重传; [AUTO-TRANSLATED:e8ddf0fd] + // A MediaStreamTrack usually contains one or more RTP streams, for example, a video MediaStreamTrack usually contains two RTP streams, one for regular transmission and one for nack retransmission; + // 一个 MediaStream 通常包含一个或多个 MediaStreamTrack ,例如 simulcast 场景下,一个 MediaStream 通常会包含三个不同编码质量的 MediaStreamTrack ; [AUTO-TRANSLATED:31318d43] + // A MediaStream usually contains one or more MediaStreamTrack, for example, in a simulcast scenario, a MediaStream usually contains three MediaStreamTrack of different encoding quality; + // 这种标记方式并不被 Firefox 认可,在 Firefox 生成的 SDP 中一个 a=ssrc 通常只有一行,例如: [AUTO-TRANSLATED:8c2c424c] + // This marking method is not recognized by Firefox, in the SDP generated by Firefox, one a=ssrc usually has only one line, for example: + // a=ssrc:3245185839 cname:Cx4i/VTR51etgjT7 + + uint32_t ssrc; + std::string attribute; + std::string attribute_value; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "ssrc"; } +}; + +class SdpAttrSSRCGroup : public SdpItem { +public: + // a=ssrc-group 定义参考 RFC 5576(https://tools.ietf.org/html/rfc5576) ,用于描述多个 ssrc 之间的关联,常见的有两种: [AUTO-TRANSLATED:a87cbcc6] + // a=ssrc-group definition refers to RFC 5576(https://tools.ietf.org/html/rfc5576), used to describe the association between multiple ssrcs, there are two common types: + // a=ssrc-group:FID 2430709021 3715850271 + // FID (Flow Identification) 最初用在 FEC 的关联中,WebRTC 中通常用于关联一组常规 RTP stream 和 重传 RTP stream 。 [AUTO-TRANSLATED:f2c0fcbb] + // FID (Flow Identification) was originally used in FEC association, and in WebRTC it is usually used to associate a group of regular RTP streams and retransmission RTP streams. + // a=ssrc-group:SIM 360918977 360918978 360918980 + // 在 Chrome 独有的 SDP munging 风格的 simulcast 中使用,将三组编码质量由低到高的 MediaStreamTrack 关联在一起。 [AUTO-TRANSLATED:61bf7596] + // Used in Chrome's unique SDP munging style simulcast, associating three groups of MediaStreamTrack from low to high encoding quality. + std::string type { "FID" }; + std::vector ssrcs; + + bool isFID() const { return type == "FID"; } + bool isSIM() const { return type == "SIM"; } + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "ssrc-group"; } +}; + +class SdpAttrSctpMap : public SdpItem { +public: + // https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-05 + // a=sctpmap:5000 webrtc-datachannel 1024 + // a=sctpmap: sctpmap-number media-subtypes [streams] + uint16_t port = 0; + std::string subtypes; + uint32_t streams = 0; + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "sctpmap"; } + bool empty() const { return port == 0 && subtypes.empty() && streams == 0; } +}; + +class SdpAttrCandidate : public SdpItem { +public: + using Ptr = std::shared_ptr; + // https://tools.ietf.org/html/rfc5245 + // 15.1. "candidate" Attribute + // a=candidate:4 1 udp 2 192.168.1.7 58107 typ host + // a=candidate:
typ + std::string foundation; + // 传输媒体的类型,1代表RTP;2代表 RTCP。 [AUTO-TRANSLATED:9ec924a6] + // The type of media to be transmitted, 1 represents RTP; 2 represents RTCP. + uint32_t component; + std::string transport { "udp" }; + uint32_t priority; + std::string address; + uint16_t port; + std::string type; + std::vector> arr; + + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "candidate"; } +}; + +class SdpAttrMsid : public SdpItem { +public: + const char *getKey() const override { return "msid"; } +}; + +class SdpAttrExtmapAllowMixed : public SdpItem { +public: + const char *getKey() const override { return "extmap-allow-mixed"; } +}; + +class SdpAttrSimulcast : public SdpItem { +public: + // https://www.meetecho.com/blog/simulcast-janus-ssrc/ + // https://tools.ietf.org/html/draft-ietf-mmusic-sdp-simulcast-14 + const char *getKey() const override { return "simulcast"; } + void parse(const std::string &str) override; + std::string toString() const override; + bool empty() const { return rids.empty(); } + std::string direction; + std::vector rids; +}; + +class SdpAttrRid : public SdpItem { +public: + void parse(const std::string &str) override; + std::string toString() const override; + const char *getKey() const override { return "rid"; } + std::string direction; + std::string rid; +}; + +class RtcSdpBase { +public: + void addItem(SdpItem::Ptr item) { items.push_back(std::move(item)); } + void addAttr(SdpItem::Ptr attr) { + auto item = std::make_shared(); + item->detail = std::move(attr); + items.push_back(std::move(item)); + } + + virtual ~RtcSdpBase() = default; + virtual std::string toString() const; + void toRtsp(); + + RtpDirection getDirection() const; + + template + cls getItemClass(char key, const char *attr_key = nullptr) const { + auto item = std::dynamic_pointer_cast(getItem(key, attr_key)); + if (!item) { + return cls(); + } + return *item; + } + + std::string getStringItem(char key, const char *attr_key = nullptr) const { + auto item = getItem(key, attr_key); + if (!item) { + return ""; + } + return item->toString(); + } + + SdpItem::Ptr getItem(char key, const char *attr_key = nullptr) const; + + template + std::vector getAllItem(char key_c, const char *attr_key = nullptr) const { + std::vector ret; + std::string key(1, key_c); + for (auto item : items) { + if (strcasecmp(item->getKey(), key.data()) == 0) { + if (!attr_key) { + auto c = std::dynamic_pointer_cast(item); + if (c) { + ret.emplace_back(*c); + } + } else { + auto attr = std::dynamic_pointer_cast(item); + if (attr && !strcasecmp(attr->detail->getKey(), attr_key)) { + auto c = std::dynamic_pointer_cast(attr->detail); + if (c) { + ret.emplace_back(*c); + } + } + } + } + } + return ret; + } + +private: + std::vector items; +}; + +class RtcSessionSdp : public RtcSdpBase { +public: + using Ptr = std::shared_ptr; + int getVersion() const; + SdpOrigin getOrigin() const; + std::string getSessionName() const; + std::string getSessionInfo() const; + SdpTime getSessionTime() const; + SdpConnection getConnection() const; + SdpBandwidth getBandwidth() const; + + std::string getUri() const; + std::string getEmail() const; + std::string getPhone() const; + std::string getTimeZone() const; + std::string getEncryptKey() const; + std::string getRepeatTimes() const; + + std::vector medias; + void parse(const std::string &str); + std::string toString() const override; +}; + +////////////////////////////////////////////////////////////////// + +// ssrc相关信息 [AUTO-TRANSLATED:954c641d] +// ssrc related information +class RtcSSRC { +public: + uint32_t ssrc { 0 }; + uint32_t rtx_ssrc { 0 }; + std::string cname; + std::string msid; + std::string mslabel; + std::string label; + + bool empty() const { return ssrc == 0 && cname.empty(); } +}; + +// rtc传输编码方案 [AUTO-TRANSLATED:8b911508] +// rtc transmission encoding scheme +class RtcCodecPlan { +public: + using Ptr = std::shared_ptr; + uint8_t pt; + std::string codec; + uint32_t sample_rate; + // 音频时有效 [AUTO-TRANSLATED:5b230fc8] + // Valid for audio + uint32_t channel = 0; + // rtcp反馈 [AUTO-TRANSLATED:580378bd] + // RTCP feedback + std::set rtcp_fb; + std::map fmtp; + + std::string getFmtp(const char *key) const; +}; + +// rtc 媒体描述 [AUTO-TRANSLATED:b1711a11] +// RTC media description +class RtcMedia { +public: + TrackType type { TrackType::TrackInvalid }; + std::string mid; + uint16_t port { 0 }; + SdpConnection addr; + SdpBandwidth bandwidth; + std::string proto; + RtpDirection direction { RtpDirection::invalid }; + std::vector plan; + + //////// rtp //////// + std::vector rtp_rtx_ssrc; + + //////// simulcast //////// + std::vector rtp_ssrc_sim; + std::vector rtp_rids; + + //////// rtcp //////// + bool rtcp_mux { false }; + bool rtcp_rsize { false }; + SdpAttrRtcp rtcp_addr; + + //////// ice //////// + bool ice_trickle { false }; + bool ice_lite { false }; + bool ice_renomination { false }; + std::string ice_ufrag; + std::string ice_pwd; + std::vector candidate; + + //////// dtls //////// + DtlsRole role { DtlsRole::invalid }; + SdpAttrFingerprint fingerprint; + + //////// extmap //////// + std::vector extmap; + + //////// sctp //////////// + SdpAttrSctpMap sctpmap; + uint32_t sctp_port { 0 }; + + void checkValid() const; + const RtcCodecPlan *getPlan(uint8_t pt) const; + const RtcCodecPlan *getPlan(const char *codec) const; + const RtcCodecPlan *getRelatedRtxPlan(uint8_t pt) const; + uint32_t getRtpSSRC() const; + uint32_t getRtxSSRC() const; + bool supportSimulcast() const; +}; + +class RtcSession { +public: + using Ptr = std::shared_ptr; + + uint32_t version; + SdpOrigin origin; + std::string session_name; + std::string session_info; + SdpTime time; + SdpConnection connection; + SdpAttrMsidSemantic msid_semantic; + std::vector media; + SdpAttrGroup group; + + void loadFrom(const std::string &sdp); + void checkValid() const; + std::string toString() const; + std::string toRtspSdp() const; + const RtcMedia *getMedia(TrackType type) const; + bool supportRtcpFb(const std::string &name, TrackType type = TrackType::TrackVideo) const; + bool supportSimulcast() const; + bool isOnlyDatachannel() const; + +private: + RtcSessionSdp::Ptr toRtcSessionSdp() const; +}; + +class RtcConfigure { +public: + using Ptr = std::shared_ptr; + class RtcTrackConfigure { + public: + bool rtcp_mux; + bool rtcp_rsize; + bool group_bundle; + bool support_rtx; + bool support_red; + bool support_ulpfec; + bool ice_lite; + bool ice_trickle; + bool ice_renomination; + std::string ice_ufrag; + std::string ice_pwd; + + RtpDirection direction { RtpDirection::invalid }; + SdpAttrFingerprint fingerprint; + + std::set rtcp_fb; + std::map extmap; + std::vector preferred_codec; + std::vector candidate; + + void setDefaultSetting(TrackType type); + void enableTWCC(bool enable = true); + void enableREMB(bool enable = true); + }; + + RtcTrackConfigure video; + RtcTrackConfigure audio; + RtcTrackConfigure application; + + void setDefaultSetting(std::string ice_ufrag, std::string ice_pwd, RtpDirection direction, const SdpAttrFingerprint &fingerprint); + void addCandidate(const SdpAttrCandidate &candidate, TrackType type = TrackInvalid); + + std::shared_ptr createOffer() const; + std::shared_ptr createAnswer(const RtcSession &offer) const; + + void setPlayRtspInfo(const std::string &sdp); + + void enableTWCC(bool enable = true, TrackType type = TrackInvalid); + void enableREMB(bool enable = true, TrackType type = TrackInvalid); + +private: + void createMediaOffer(const std::shared_ptr &ret) const; + void createMediaOfferEach(const std::shared_ptr &ret, TrackType type, int index) const; + void matchMedia(const std::shared_ptr &ret, const RtcMedia &media) const; + bool onCheckCodecProfile(const RtcCodecPlan &plan, CodecId codec) const; + void onSelectPlan(RtcCodecPlan &plan, CodecId codec) const; + +private: + RtcCodecPlan::Ptr _rtsp_video_plan; + RtcCodecPlan::Ptr _rtsp_audio_plan; +}; + +class SdpConst { +public: + static std::string const kTWCCRtcpFb; + static std::string const kRembRtcpFb; + +private: + SdpConst() = delete; + ~SdpConst() = delete; +}; + +} // namespace mediakit + +#endif // ZLMEDIAKIT_SDP_H diff --git a/webrtc/SrtpSession.cpp b/webrtc/SrtpSession.cpp index 8aaa368a..408769fc 100644 --- a/webrtc/SrtpSession.cpp +++ b/webrtc/SrtpSession.cpp @@ -204,7 +204,16 @@ SrtpSession::SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_ policy.key = key; // Required for sending RTP retransmission without RTX. policy.allow_repeat_tx = 1; - policy.window_size = 1024; +#if 0 + if (type == Type::OUTBOUND) { + policy.window_size = 0x8000 - 1; + } else { + policy.window_size = 1024; + } +#else + // TODO 关闭防重放攻击 + policy.window_size = 0x8000 - 1; +#endif policy.next = nullptr; // Set the SRTP session. diff --git a/webrtc/SrtpSession.hpp b/webrtc/SrtpSession.hpp index 5e413aea..eed09473 100644 --- a/webrtc/SrtpSession.hpp +++ b/webrtc/SrtpSession.hpp @@ -1,64 +1,65 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef MS_RTC_SRTP_SESSION_HPP -#define MS_RTC_SRTP_SESSION_HPP - -#include "Utils.hpp" - -#include - -typedef struct srtp_ctx_t_ *srtp_t; - -namespace RTC { - -class DepLibSRTP; - -class SrtpSession { -public: - enum class CryptoSuite { - NONE = 0, - AES_CM_128_HMAC_SHA1_80 = 1, - AES_CM_128_HMAC_SHA1_32, - AEAD_AES_256_GCM, - AEAD_AES_128_GCM - }; - -public: - enum class Type { INBOUND = 1, OUTBOUND }; - -public: - SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen); - ~SrtpSession(); - -public: - bool EncryptRtp(uint8_t *data, int *len); - bool DecryptSrtp(uint8_t *data, int *len); - bool EncryptRtcp(uint8_t *data, int *len); - bool DecryptSrtcp(uint8_t *data, int *len); - void RemoveStream(uint32_t ssrc); - -private: - // Allocated by this. - srtp_t session { nullptr }; - std::shared_ptr _env; -}; - -} // namespace RTC - -#endif +/** +ISC License + +Copyright © 2015, Iñaki Baz Castillo + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF +OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +#ifndef MS_RTC_SRTP_SESSION_HPP +#define MS_RTC_SRTP_SESSION_HPP + +#include "Util/Byte.hpp" + +#include + +typedef struct srtp_ctx_t_ *srtp_t; + +namespace RTC { + +class DepLibSRTP; + +class SrtpSession { +public: + using Ptr = std::shared_ptr; + enum class CryptoSuite { + NONE = 0, + AES_CM_128_HMAC_SHA1_80 = 1, + AES_CM_128_HMAC_SHA1_32, + AEAD_AES_256_GCM, + AEAD_AES_128_GCM + }; + +public: + enum class Type { INBOUND = 1, OUTBOUND }; + +public: + SrtpSession(Type type, CryptoSuite cryptoSuite, uint8_t *key, size_t keyLen); + ~SrtpSession(); + +public: + bool EncryptRtp(uint8_t *data, int *len); + bool DecryptSrtp(uint8_t *data, int *len); + bool EncryptRtcp(uint8_t *data, int *len); + bool DecryptSrtcp(uint8_t *data, int *len); + void RemoveStream(uint32_t ssrc); + +private: + // Allocated by this. + srtp_t session { nullptr }; + std::shared_ptr _env; +}; + +} // namespace RTC + +#endif diff --git a/webrtc/StunPacket.cpp b/webrtc/StunPacket.cpp index 9f648d55..506b74b3 100644 --- a/webrtc/StunPacket.cpp +++ b/webrtc/StunPacket.cpp @@ -1,881 +1,906 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#define MS_CLASS "RTC::StunPacket" -// #define MS_LOG_DEV_LEVEL 3 +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. +*/ #include "StunPacket.hpp" #include // std::snprintf() #include // std::memcmp(), std::memcpy() #include +#include "Util/logger.h" +#include "Common/macros.h" + +using namespace std; +using namespace toolkit; namespace RTC { - static const uint32_t crc32Table[] = - { - 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, - 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, - 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, - 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, - 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, - 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, - 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, - 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, - 0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, - 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01, - 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, - 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, - 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, - 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, - 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, - 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, - 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, - 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, - 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, - 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, - 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, - 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, - 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, - 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, - 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713, - 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, - 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, - 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, - 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, - 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, - 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, - 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d - }; - inline uint32_t GetCRC32(const uint8_t *data, size_t size) { - uint32_t crc{0xFFFFFFFF}; - const uint8_t *p = data; +static const uint32_t crc32Table[] = { + 0x00000000, 0x77073096, 0xee0e612c, 0x990951ba, 0x076dc419, 0x706af48f, 0xe963a535, 0x9e6495a3, + 0x0edb8832, 0x79dcb8a4, 0xe0d5e91e, 0x97d2d988, 0x09b64c2b, 0x7eb17cbd, 0xe7b82d07, 0x90bf1d91, + 0x1db71064, 0x6ab020f2, 0xf3b97148, 0x84be41de, 0x1adad47d, 0x6ddde4eb, 0xf4d4b551, 0x83d385c7, + 0x136c9856, 0x646ba8c0, 0xfd62f97a, 0x8a65c9ec, 0x14015c4f, 0x63066cd9, 0xfa0f3d63, 0x8d080df5, + 0x3b6e20c8, 0x4c69105e, 0xd56041e4, 0xa2677172, 0x3c03e4d1, 0x4b04d447, 0xd20d85fd, 0xa50ab56b, + 0x35b5a8fa, 0x42b2986c, 0xdbbbc9d6, 0xacbcf940, 0x32d86ce3, 0x45df5c75, 0xdcd60dcf, 0xabd13d59, + 0x26d930ac, 0x51de003a, 0xc8d75180, 0xbfd06116, 0x21b4f4b5, 0x56b3c423, 0xcfba9599, 0xb8bda50f, + 0x2802b89e, 0x5f058808, 0xc60cd9b2, 0xb10be924, 0x2f6f7c87, 0x58684c11, 0xc1611dab, 0xb6662d3d, + 0x76dc4190, 0x01db7106, 0x98d220bc, 0xefd5102a, 0x71b18589, 0x06b6b51f, 0x9fbfe4a5, 0xe8b8d433, + 0x7807c9a2, 0x0f00f934, 0x9609a88e, 0xe10e9818, 0x7f6a0dbb, 0x086d3d2d, 0x91646c97, 0xe6635c01, + 0x6b6b51f4, 0x1c6c6162, 0x856530d8, 0xf262004e, 0x6c0695ed, 0x1b01a57b, 0x8208f4c1, 0xf50fc457, + 0x65b0d9c6, 0x12b7e950, 0x8bbeb8ea, 0xfcb9887c, 0x62dd1ddf, 0x15da2d49, 0x8cd37cf3, 0xfbd44c65, + 0x4db26158, 0x3ab551ce, 0xa3bc0074, 0xd4bb30e2, 0x4adfa541, 0x3dd895d7, 0xa4d1c46d, 0xd3d6f4fb, + 0x4369e96a, 0x346ed9fc, 0xad678846, 0xda60b8d0, 0x44042d73, 0x33031de5, 0xaa0a4c5f, 0xdd0d7cc9, + 0x5005713c, 0x270241aa, 0xbe0b1010, 0xc90c2086, 0x5768b525, 0x206f85b3, 0xb966d409, 0xce61e49f, + 0x5edef90e, 0x29d9c998, 0xb0d09822, 0xc7d7a8b4, 0x59b33d17, 0x2eb40d81, 0xb7bd5c3b, 0xc0ba6cad, + 0xedb88320, 0x9abfb3b6, 0x03b6e20c, 0x74b1d29a, 0xead54739, 0x9dd277af, 0x04db2615, 0x73dc1683, + 0xe3630b12, 0x94643b84, 0x0d6d6a3e, 0x7a6a5aa8, 0xe40ecf0b, 0x9309ff9d, 0x0a00ae27, 0x7d079eb1, + 0xf00f9344, 0x8708a3d2, 0x1e01f268, 0x6906c2fe, 0xf762575d, 0x806567cb, 0x196c3671, 0x6e6b06e7, + 0xfed41b76, 0x89d32be0, 0x10da7a5a, 0x67dd4acc, 0xf9b9df6f, 0x8ebeeff9, 0x17b7be43, 0x60b08ed5, + 0xd6d6a3e8, 0xa1d1937e, 0x38d8c2c4, 0x4fdff252, 0xd1bb67f1, 0xa6bc5767, 0x3fb506dd, 0x48b2364b, + 0xd80d2bda, 0xaf0a1b4c, 0x36034af6, 0x41047a60, 0xdf60efc3, 0xa867df55, 0x316e8eef, 0x4669be79, + 0xcb61b38c, 0xbc66831a, 0x256fd2a0, 0x5268e236, 0xcc0c7795, 0xbb0b4703, 0x220216b9, 0x5505262f, + 0xc5ba3bbe, 0xb2bd0b28, 0x2bb45a92, 0x5cb36a04, 0xc2d7ffa7, 0xb5d0cf31, 0x2cd99e8b, 0x5bdeae1d, + 0x9b64c2b0, 0xec63f226, 0x756aa39c, 0x026d930a, 0x9c0906a9, 0xeb0e363f, 0x72076785, 0x05005713, + 0x95bf4a82, 0xe2b87a14, 0x7bb12bae, 0x0cb61b38, 0x92d28e9b, 0xe5d5be0d, 0x7cdcefb7, 0x0bdbdf21, + 0x86d3d2d4, 0xf1d4e242, 0x68ddb3f8, 0x1fda836e, 0x81be16cd, 0xf6b9265b, 0x6fb077e1, 0x18b74777, + 0x88085ae6, 0xff0f6a70, 0x66063bca, 0x11010b5c, 0x8f659eff, 0xf862ae69, 0x616bffd3, 0x166ccf45, + 0xa00ae278, 0xd70dd2ee, 0x4e048354, 0x3903b3c2, 0xa7672661, 0xd06016f7, 0x4969474d, 0x3e6e77db, + 0xaed16a4a, 0xd9d65adc, 0x40df0b66, 0x37d83bf0, 0xa9bcae53, 0xdebb9ec5, 0x47b2cf7f, 0x30b5ffe9, + 0xbdbdf21c, 0xcabac28a, 0x53b39330, 0x24b4a3a6, 0xbad03605, 0xcdd70693, 0x54de5729, 0x23d967bf, + 0xb3667a2e, 0xc4614ab8, 0x5d681b02, 0x2a6f2b94, 0xb40bbe37, 0xc30c8ea1, 0x5a05df1b, 0x2d02ef8d +}; - while (size--) { - crc = crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8); - } +inline uint32_t getCRC32(const uint8_t *data, size_t size) { + uint32_t crc { 0xFFFFFFFF }; + const uint8_t *p = data; - return crc ^ ~0U; + while (size--) { + crc = crc32Table[(crc ^ *p++) & 0xFF] ^ (crc >> 8); } - static std::string openssl_HMACsha1(const void *key, size_t key_len, const void *data, size_t data_len){ - std::string str; - str.resize(20); - unsigned int out_len; + return crc ^ ~0U; +} + +static std::string openssl_HMACsha1(const void *key, size_t key_len, const void *data, size_t data_len) { + std::string str; + str.resize(20); + unsigned int out_len; #if defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) - //openssl 1.1.0新增api,老版本api作废 - HMAC_CTX *ctx = HMAC_CTX_new(); - HMAC_CTX_reset(ctx); - HMAC_Init_ex(ctx, key, (int)key_len, EVP_sha1(), NULL); - HMAC_Update(ctx, (unsigned char*)data, data_len); - HMAC_Final(ctx, (unsigned char *)str.data(), &out_len); - HMAC_CTX_reset(ctx); - HMAC_CTX_free(ctx); + // openssl 1.1.0新增api,老版本api作废 + HMAC_CTX *ctx = HMAC_CTX_new(); + HMAC_CTX_reset(ctx); + HMAC_Init_ex(ctx, key, (int)key_len, EVP_sha1(), NULL); + HMAC_Update(ctx, (unsigned char *)data, data_len); + HMAC_Final(ctx, (unsigned char *)str.data(), &out_len); + HMAC_CTX_reset(ctx); + HMAC_CTX_free(ctx); #else - HMAC_CTX ctx; - HMAC_CTX_init(&ctx); - HMAC_Init_ex(&ctx, key, key_len, EVP_sha1(), NULL); - HMAC_Update(&ctx, (unsigned char*)data, data_len); - HMAC_Final(&ctx, (unsigned char *)str.data(), &out_len); - HMAC_CTX_cleanup(&ctx); -#endif //defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) - return str; + HMAC_CTX ctx; + HMAC_CTX_init(&ctx); + HMAC_Init_ex(&ctx, key, key_len, EVP_sha1(), NULL); + HMAC_Update(&ctx, (unsigned char *)data, data_len); + HMAC_Final(&ctx, (unsigned char *)str.data(), &out_len); + HMAC_CTX_cleanup(&ctx); +#endif // defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) + return str; +} + +static std::string openssl_MD5(const void *data, size_t data_len) { + std::string str; + str.resize(16); + unsigned int out_len; +#if defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) + // openssl 1.1.0新增api,老版本api作废 + EVP_MD_CTX *ctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(ctx, EVP_md5(), NULL); + EVP_DigestUpdate(ctx, data, data_len); + unsigned int md_len; + EVP_DigestFinal_ex(ctx, (unsigned char *)str.data(), &md_len); + EVP_MD_CTX_free(ctx); +#else + MD5_CTX ctx; + MD5_Init(&ctx); + MD5_Update(&ctx, data, data_len); + MD5_Final((unsigned char *)str.data(), &ctx); +#endif // defined(OPENSSL_VERSION_NUMBER) && (OPENSSL_VERSION_NUMBER > 0x10100000L) + return str; +} + +/////////////////////////////////////////////////// +// StunAttribute + +bool StunAttribute::isComprehensionRequired(const uint8_t *data, size_t len) { + return ((data[0] & 0xC0) == 0x00); +} + +const uint8_t * StunAttribute::loadHeader(const uint8_t *buf) { + _type = (Type)Byte::Get2Bytes(buf, 0); + _length = Byte::Get2Bytes(buf, 2); + return buf + ATTR_HEADER_SIZE; +} + +uint8_t * StunAttribute::storeHeader() { + _data = toolkit::BufferRaw::create(ATTR_HEADER_SIZE + Byte::PadTo4Bytes(_length)); + _data->setSize(_data->getCapacity()); + memset(_data->data(), 0, _data->size()); + uint8_t *ptr = (uint8_t *)_data->data(); + Byte::Set2Bytes(ptr, 0, (uint16_t)_type); + Byte::Set2Bytes(ptr, 2, _length); + return ptr + ATTR_HEADER_SIZE; +} + +bool StunAttrMappedAddress::loadFromData(const uint8_t *buf, size_t len) { + StunAttribute::loadHeader(buf); + return true; +} + +bool StunAttrMappedAddress::storeToData() { + return true; +} + +bool StunAttrUserName::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _username.assign((const char *)p, _length); + return true; +} + +bool StunAttrUserName::storeToData() { + _length = _username.length(); + auto p = StunAttribute::storeHeader(); + memcpy(p, _username.data(), _username.length()); + return true; +} + +bool StunAttrMessageIntegrity::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _hmac.assign((const char *)p, _length); + return true; +} + +bool StunAttrMessageIntegrity::storeToData() { + _length = _hmac.size(); + auto p = StunAttribute::storeHeader(); + memcpy(p, _hmac.data(), _hmac.size()); + return true; +} + +bool StunAttrErrorCode::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _error_code = (Code)(p[2] * 100 + p[3]); + return true; +} + +bool StunAttrErrorCode::storeToData() { + _length = 4; + auto p = StunAttribute::storeHeader(); + Byte::Set2Bytes(p, 0, 0); // reserved + uint16_t code = (uint16_t)_error_code; + p[2] = code / 100; + p[3] = code % 100; + return true; +} + +bool StunAttrChannelNumber::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _channel_number = Byte::Get2Bytes(p, 0); + return true; +} + +bool StunAttrChannelNumber::storeToData() { + _length = 4; + auto p = StunAttribute::storeHeader(); + Byte::Set2Bytes(p, 0, _channel_number); + Byte::Set2Bytes(p, 2, 0); // RFFU + return true; +} + +bool StunAttrLifeTime::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _lifetime = Byte::Get4Bytes(p, 0); + return true; +} + +bool StunAttrLifeTime::storeToData() { + _length = 4; + auto p = StunAttribute::storeHeader(); + Byte::Set4Bytes(p, 0, _lifetime); + return true; +} + +bool StunAttrXorPeerAddress::loadFromData(const uint8_t *buf, size_t len) { + auto attrValue = StunAttribute::loadHeader(buf); + memset(&_addr, 0, sizeof(_addr)); + uint8_t port[2], addr[16]; + port[0] = attrValue[2] ^ StunPacket::_magicCookie[0]; + port[1] = attrValue[3] ^ StunPacket::_magicCookie[1]; + addr[0] = attrValue[4] ^ StunPacket::_magicCookie[0]; + addr[1] = attrValue[5] ^ StunPacket::_magicCookie[1]; + addr[2] = attrValue[6] ^ StunPacket::_magicCookie[2]; + addr[3] = attrValue[7] ^ StunPacket::_magicCookie[3]; + auto protocol = attrValue[1]; + if (protocol == 0x01) { + _addr.ss_family = AF_INET; + struct sockaddr_in *ipv4 = (struct sockaddr_in *)&_addr; + ipv4->sin_port = ntohs(Byte::Get2Bytes(port, 0)); + std::memcpy((void *)&(reinterpret_cast(&_addr))->sin_addr.s_addr, addr, 4); + } else { + _addr.ss_family = AF_INET6; + for (int i=0; i < 12; ++i) { + addr[i + 4] = attrValue[i + 8] ^ _transaction_id[i]; + } + struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)&_addr; + ipv6->sin6_port = ntohs(Byte::Get2Bytes(port, 0)); + std::memcpy((void *)&(reinterpret_cast(&_addr))->sin6_addr.s6_addr, addr, 16); } - /* Class variables. */ + return true; +} - const uint8_t StunPacket::magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; +bool StunAttrXorPeerAddress::storeToData() { + _length = (_addr.ss_family == AF_INET) ? 8 : 20; + auto attrValue = StunAttribute::storeHeader(); + // Set first byte to 0. + attrValue[0] = 0; + if (_addr.ss_family == AF_INET) { + // Set inet family. + attrValue[1] = 1; + // Set port and XOR it. + std::memcpy(attrValue + 2, &(reinterpret_cast(&_addr))->sin_port, 2); + attrValue[2] ^= StunPacket::_magicCookie[0]; + attrValue[3] ^= StunPacket::_magicCookie[1]; - /* Class methods. */ + // Set address and XOR it. + std::memcpy(attrValue + 4, &(reinterpret_cast(&_addr))->sin_addr.s_addr, 4); + attrValue[4] ^= StunPacket::_magicCookie[0]; + attrValue[5] ^= StunPacket::_magicCookie[1]; + attrValue[6] ^= StunPacket::_magicCookie[2]; + attrValue[7] ^= StunPacket::_magicCookie[3]; + } else if (_addr.ss_family == AF_INET6) { + // Set inet family. + attrValue[1] = 2; - StunPacket* StunPacket::Parse(const uint8_t* data, size_t len) - { - MS_TRACE(); - - if (!StunPacket::IsStun(data, len)) - return nullptr; - - /* - The message type field is decomposed further into the following - structure: - - 0 1 - 2 3 4 5 6 7 8 9 0 1 2 3 4 5 - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - |M |M |M|M|M|C|M|M|M|C|M|M|M|M| - |11|10|9|8|7|1|6|5|4|0|3|2|1|0| - +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ - - Figure 3: Format of STUN Message Type Field - - Here the bits in the message type field are shown as most significant - (M11) through least significant (M0). M11 through M0 represent a 12- - bit encoding of the method. C1 and C0 represent a 2-bit encoding of - the class. - */ - - // Get type field. - uint16_t msgType = Utils::Byte::Get2Bytes(data, 0); - - // Get length field. - uint16_t msgLength = Utils::Byte::Get2Bytes(data, 2); - - // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. - if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) - { - MS_WARN_TAG( - ice, - "length field + 20 does not match total size (or it is not multiple of 4 bytes), " - "packet discarded"); - - return nullptr; + std::memcpy(attrValue + 2, &(reinterpret_cast(&_addr))->sin6_port, 2); + attrValue[2] ^= StunPacket::_magicCookie[0]; + attrValue[3] ^= StunPacket::_magicCookie[1]; + // Set address and XOR it. + std::memcpy(attrValue + 4, &(reinterpret_cast(&_addr))->sin6_addr.s6_addr, 16); + attrValue[4] ^= StunPacket::_magicCookie[0]; + attrValue[5] ^= StunPacket::_magicCookie[1]; + attrValue[6] ^= StunPacket::_magicCookie[2]; + attrValue[7] ^= StunPacket::_magicCookie[3]; + for (int i=0; i < 12; ++i) { + attrValue[8 + i] ^= _transaction_id[i]; } - - // Get STUN method. - uint16_t msgMethod = (msgType & 0x000f) | ((msgType & 0x00e0) >> 1) | ((msgType & 0x3E00) >> 2); - - // Get STUN class. - uint16_t msgClass = ((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4); - - // Create a new StunPacket (data + 8 points to the received TransactionID field). - auto* packet = new StunPacket( - static_cast(msgClass), static_cast(msgMethod), data + 8, data, len); - - /* - STUN Attributes - - After the STUN header are zero or more attributes. Each attribute - MUST be TLV encoded, with a 16-bit type, 16-bit length, and value. - Each STUN attribute MUST end on a 32-bit boundary. As mentioned - above, all fields in an attribute are transmitted most significant - bit first. - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Type | Length | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Value (variable) .... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - */ - - // Start looking for attributes after STUN header (Byte #20). - size_t pos{ 20 }; - // Flags (positions) for special MESSAGE-INTEGRITY and FINGERPRINT attributes. - bool hasMessageIntegrity{ false }; - bool hasFingerprint{ false }; - size_t fingerprintAttrPos; // Will point to the beginning of the attribute. - uint32_t fingerprint; // Holds the value of the FINGERPRINT attribute. - - // Ensure there are at least 4 remaining bytes (attribute with 0 length). - while (pos + 4 <= len) - { - // Get the attribute type. - auto attrType = static_cast(Utils::Byte::Get2Bytes(data, pos)); - - // Get the attribute length. - uint16_t attrLength = Utils::Byte::Get2Bytes(data, pos + 2); - - // Ensure the attribute length is not greater than the remaining size. - if ((pos + 4 + attrLength) > len) - { - MS_WARN_TAG(ice, "the attribute length exceeds the remaining size, packet discarded"); - - delete packet; - return nullptr; - } - - // FINGERPRINT must be the last attribute. - if (hasFingerprint) - { - MS_WARN_TAG(ice, "attribute after FINGERPRINT is not allowed, packet discarded"); - - delete packet; - return nullptr; - } - - // After a MESSAGE-INTEGRITY attribute just FINGERPRINT is allowed. - if (hasMessageIntegrity && attrType != Attribute::FINGERPRINT) - { - MS_WARN_TAG( - ice, - "attribute after MESSAGE-INTEGRITY other than FINGERPRINT is not allowed, " - "packet discarded"); - - delete packet; - return nullptr; - } - - const uint8_t* attrValuePos = data + pos + 4; - - switch (attrType) - { - case Attribute::USERNAME: - { - packet->SetUsername( - reinterpret_cast(attrValuePos), static_cast(attrLength)); - - break; - } - - case Attribute::PRIORITY: - { - // Ensure attribute length is 4 bytes. - if (attrLength != 4) - { - MS_WARN_TAG(ice, "attribute PRIORITY must be 4 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetPriority(Utils::Byte::Get4Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::ICE_CONTROLLING: - { - // Ensure attribute length is 8 bytes. - if (attrLength != 8) - { - MS_WARN_TAG(ice, "attribute ICE-CONTROLLING must be 8 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetIceControlling(Utils::Byte::Get8Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::ICE_CONTROLLED: - { - // Ensure attribute length is 8 bytes. - if (attrLength != 8) - { - MS_WARN_TAG(ice, "attribute ICE-CONTROLLED must be 8 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetIceControlled(Utils::Byte::Get8Bytes(attrValuePos, 0)); - - break; - } - - case Attribute::USE_CANDIDATE: - { - // Ensure attribute length is 0 bytes. - if (attrLength != 0) - { - MS_WARN_TAG(ice, "attribute USE-CANDIDATE must be 0 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - packet->SetUseCandidate(); - - break; - } - - case Attribute::MESSAGE_INTEGRITY: - { - // Ensure attribute length is 20 bytes. - if (attrLength != 20) - { - MS_WARN_TAG(ice, "attribute MESSAGE-INTEGRITY must be 20 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - hasMessageIntegrity = true; - packet->SetMessageIntegrity(attrValuePos); - - break; - } - - case Attribute::FINGERPRINT: - { - // Ensure attribute length is 4 bytes. - if (attrLength != 4) - { - MS_WARN_TAG(ice, "attribute FINGERPRINT must be 4 bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - hasFingerprint = true; - fingerprintAttrPos = pos; - fingerprint = Utils::Byte::Get4Bytes(attrValuePos, 0); - packet->SetFingerprint(); - - break; - } - - case Attribute::ERROR_CODE: - { - // Ensure attribute length >= 4bytes. - if (attrLength < 4) - { - MS_WARN_TAG(ice, "attribute ERROR-CODE must be >= 4bytes length, packet discarded"); - - delete packet; - return nullptr; - } - - uint8_t errorClass = Utils::Byte::Get1Byte(attrValuePos, 2); - uint8_t errorNumber = Utils::Byte::Get1Byte(attrValuePos, 3); - auto errorCode = static_cast(errorClass * 100 + errorNumber); - - packet->SetErrorCode(errorCode); - - break; - } - - default:; - } - - // Set next attribute position. - pos = - static_cast(Utils::Byte::PadTo4Bytes(static_cast(pos + 4 + attrLength))); - } - - // Ensure current position matches the total length. - if (pos != len) - { - MS_WARN_TAG(ice, "computed packet size does not match total size, packet discarded"); - - delete packet; - return nullptr; - } - - // If it has FINGERPRINT attribute then verify it. - if (hasFingerprint) - { - // Compute the CRC32 of the received packet up to (but excluding) the - // FINGERPRINT attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = GetCRC32(data, fingerprintAttrPos) ^ 0x5354554e; - - // Compare with the FINGERPRINT value in the packet. - if (fingerprint != computedFingerprint) - { - MS_WARN_TAG( - ice, - "computed FINGERPRINT value does not match the value in the packet, " - "packet discarded"); - - delete packet; - return nullptr; - } - } - - return packet; } - /* Instance methods. */ + return true; +} - StunPacket::StunPacket( - Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size) - : klass(klass), method(method), transactionId(transactionId), data(const_cast(data)), - size(size) - { - MS_TRACE(); +bool StunAttrData::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + setData((const char *)p, _length); + return true; +} + +bool StunAttrData::storeToData() { + _length = _data_content.size(); + auto p = StunAttribute::storeHeader(); + memcpy(p, _data_content.data(), _data_content.size()); + return true; +} + +bool StunAttrRealm::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _realm.assign((const char *)p, _length); + return true; +} + +bool StunAttrRealm::storeToData() { + _length = _realm.size(); + auto p = StunAttribute::storeHeader(); + memcpy(p, _realm.data(), _realm.size()); + return true; +} + +bool StunAttrNonce::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _nonce.assign((const char *)p, _length); + return true; +} + +bool StunAttrNonce::storeToData() { + _length = _nonce.size(); + auto p = StunAttribute::storeHeader(); + memcpy(p, _nonce.data(), _nonce.size()); + return true; +} + +bool StunAttrRequestedTransport::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _protocol = (Protocol)p[0]; + return true; +} + +bool StunAttrRequestedTransport::storeToData() { + _length = 4; + auto p = StunAttribute::storeHeader(); + p[0] = (uint8_t)_protocol; + return true; +} + +bool StunAttrPriority::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _priority = Byte::Get4Bytes(p, 0); + return true; +} + +bool StunAttrPriority::storeToData() { + _length = 4; + auto p = StunAttribute::storeHeader(); + Byte::Set4Bytes(p, 0, _priority); + return true; +} + +bool StunAttrUseCandidate::loadFromData(const uint8_t *buf, size_t len) { + StunAttribute::loadHeader(buf); + return true; +} + +bool StunAttrUseCandidate::storeToData() { + _length = 0; + StunAttribute::storeHeader(); + return true; +} + +bool StunAttrFingerprint::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _fingerprint = Byte::Get4Bytes(p, 0); + return true; +} + +bool StunAttrFingerprint::storeToData() { + _length = 4; + auto p = StunAttribute::storeHeader(); + Byte::Set4Bytes(p, 0, _fingerprint); + return true; +} + +bool StunAttrIceControlled::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _tiebreaker = Byte::Get8Bytes(p, 0); + return true; +} + +bool StunAttrIceControlled::storeToData() { + _length = 8; + auto p = StunAttribute::storeHeader(); + Byte::Set8Bytes(p, 0, _tiebreaker); + return true; +} + +bool StunAttrIceControlling::loadFromData(const uint8_t *buf, size_t len) { + auto p = StunAttribute::loadHeader(buf); + _tiebreaker = Byte::Get8Bytes(p, 0); + return true; +} + +bool StunAttrIceControlling::storeToData() { + _length = 8; + auto p = StunAttribute::storeHeader(); + Byte::Set8Bytes(p, 0, _tiebreaker); + return true; +} + +/////////////////////////////////////////////////// +// StunPacket + +const uint8_t StunPacket::_magicCookie[] = { 0x21, 0x12, 0xA4, 0x42 }; + +/* Class methods. */ +bool StunPacket::isStun(const uint8_t *data, size_t len) { + // reference https://www.rfc-editor.org/rfc/rfc8489.html#section-6.3 + return + // STUN headers are 20 bytes. + (len >= 20) && + // checks that the first two bits are 0 + ((data[0] & 0xC0) == 0) && + // that the Magic Cookie field has the correct value + (data[4] == StunPacket::_magicCookie[0]) && (data[5] == StunPacket::_magicCookie[1]) && (data[6] == StunPacket::_magicCookie[2]) + && (data[7] == StunPacket::_magicCookie[3]); +} + +/* + The message type field is decomposed further into the following + structure: + + 0 1 + 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + |M |M |M|M|M|C|M|M|M|C|M|M|M|M| + |11|10|9|8|7|1|6|5|4|0|3|2|1|0| + +--+--+-+-+-+-+-+-+-+-+-+-+-+-+ + + Figure 3: Format of STUN Message Type Field + + Here the bits in the message type field are shown as most significant + (M11) through least significant (M0). M11 through M0 represent a 12- + bit encoding of the method. C1 and C0 represent a 2-bit encoding of + the class. +*/ +StunPacket::Class StunPacket::getClass(const uint8_t *data, size_t len) { + return StunPacket::Class(((data[0] & 0x01) << 1) | ((data[1] & 0x10) >> 4)); +} + +StunPacket::Method StunPacket::getMethod(const uint8_t *data, size_t len) { + uint16_t msgType = Byte::Get2Bytes(data, 0); + return StunPacket::Method((msgType & 0x000F) | ((msgType & 0x00E0) >> 1) | ((msgType & 0x3E00) >> 2)); +} + +StunPacket::Ptr StunPacket::parse(const uint8_t *data, size_t len) { + // TraceL; + + if (!StunPacket::isStun(data, len)) { + return nullptr; } - StunPacket::~StunPacket() - { - MS_TRACE(); + // Get length field. + uint16_t msgLength = Byte::Get2Bytes(data, 2); + + // length field must be total size minus header's 20 bytes, and must be multiple of 4 Bytes. + if ((static_cast(msgLength) != len - 20) || ((msgLength & 0x03) != 0)) { + WarnL << "length field + 20 does not match total size (or it is not multiple of 4 bytes), packet discarded"; + return nullptr; } -#if 0 - void StunPacket::Dump() const - { - MS_TRACE(); + auto msgMethod = getMethod(data, len); + auto msgClass = getClass(data, len); - MS_DUMP(""); + auto packet = std::make_shared(msgClass, msgMethod, (const char *)data + 8); + packet->loadFromData(data, len); + return packet; +} - std::string klass; - switch (this->klass) - { - case Class::REQUEST: - klass = "Request"; - break; - case Class::INDICATION: - klass = "Indication"; - break; - case Class::SUCCESS_RESPONSE: - klass = "SuccessResponse"; - break; - case Class::ERROR_RESPONSE: - klass = "ErrorResponse"; - break; - } - if (this->method == Method::BINDING) - { - MS_DUMP(" Binding %s", klass.c_str()); - } - else - { - // This prints the unknown method number. Example: TURN Allocate => 0x003. - MS_DUMP(" %s with unknown method %#.3x", klass.c_str(), static_cast(this->method)); - } - MS_DUMP(" size: %zu bytes", this->size); +std::string StunPacket::mappingClassEnum2Str(Class klass) { + switch (klass) { + case StunPacket::Class::REQUEST: return "REQUEST"; + case StunPacket::Class::INDICATION: return "INDICATION"; + case StunPacket::Class::SUCCESS_RESPONSE: return "SUCCESS_RESPONSE"; + case StunPacket::Class::ERROR_RESPONSE: return "ERROR_RESPONSE"; + default: break; + } + return "invalid"; +} - static char transactionId[25]; +std::string StunPacket::mappingMethodEnum2Str(Method method) { + switch (method) { + case StunPacket::Method::BINDING: return "BINDING"; + case StunPacket::Method::ALLOCATE: return "ALLOCATE"; + case StunPacket::Method::REFRESH: return "REFRESH"; + case StunPacket::Method::SEND: return "SEND"; + case StunPacket::Method::DATA: return "DATA"; + case StunPacket::Method::CREATEPERMISSION: return "CREATEPERMISSION"; + case StunPacket::Method::CHANNELBIND: return "CHANNELBIND"; + default: break; + } + return "invalid"; +} - for (int i{ 0 }; i < 12; ++i) - { - // NOTE: n must be 3 because snprintf adds a \0 after printed chars. - std::snprintf(transactionId + (i * 2), 3, "%.2x", this->transactionId[i]); - } - MS_DUMP(" transactionId: %s", transactionId); - if (this->errorCode != 0u) - MS_DUMP(" errorCode: %" PRIu16, this->errorCode); - if (!this->username.empty()) - MS_DUMP(" username: %s", this->username.c_str()); - if (this->priority != 0u) - MS_DUMP(" priority: %" PRIu32, this->priority); - if (this->iceControlling != 0u) - MS_DUMP(" iceControlling: %" PRIu64, this->iceControlling); - if (this->iceControlled != 0u) - MS_DUMP(" iceControlled: %" PRIu64, this->iceControlled); - if (this->hasUseCandidate) - MS_DUMP(" useCandidate"); - if (this->xorMappedAddress != nullptr) - { - int family; - uint16_t port; - std::string ip; +StunPacket::StunPacket(Class klass, Method method, const char* transId) + : _klass(klass) + , _method(method) { + // TraceL; + if (transId) { + _transaction_id.assign(transId, 12); + } else { + refreshTransactionId(); + } +} - Utils::IP::GetAddressInfo(this->xorMappedAddress, family, ip, port); +StunPacket::~StunPacket() { + // TraceL; +} - MS_DUMP(" xorMappedAddress: %s : %" PRIu16, ip.c_str(), port); - } - if (this->messageIntegrity != nullptr) - { - static char messageIntegrity[41]; +std::string StunPacket::dumpString(bool transId) const { + std::string ret = "class=" + getClassStr() + ", method=" + getMethodStr(); + if (transId) { + ret += ", transaction=" + hexdump(_transaction_id.data(), _transaction_id.size()); + } + return ret; +} - for (int i{ 0 }; i < 20; ++i) - { - std::snprintf(messageIntegrity + (i * 2), 3, "%.2x", this->messageIntegrity[i]); +void StunPacket::addAttribute(StunAttribute::Ptr attr) { + _attribute_map.emplace(attr->type(), std::move(attr)); +} + +void StunPacket::removeAttribute(StunAttribute::Type type) { + _attribute_map.erase(type); +} + +bool StunPacket::hasAttribute(StunAttribute::Type type) const { + return _attribute_map.count(type) > 0; +} + +StunAttribute::Ptr StunPacket::getAttribute(StunAttribute::Type type) const { + auto it = _attribute_map.find(type); + if (it != _attribute_map.end()) { + return it->second; + } + return nullptr; +} + +std::string StunPacket::getUsername() const { + auto attr = getAttribute(); + return attr ? attr->getUsername() : ""; +} + +uint64_t StunPacket::getPriority() const { + auto attr = getAttribute(); + return attr ? attr->getPriority() : 0; +} + +StunAttrErrorCode::Code StunPacket::getErrorCode() const { + auto attr = getAttribute(); + return attr ? attr->getErrorCode() : StunAttrErrorCode::Code::Invalid; +} + +StunPacket::Authentication StunPacket::checkAuthentication(const std::string &ufrag, const std::string &password) const { + // TraceL; + auto attr_message_integrity = getAttribute(); + switch (_klass) { + case Class::REQUEST: { + if (!attr_message_integrity) { + return Authentication::UNAUTHORIZED; } - MS_DUMP(" messageIntegrity: %s", messageIntegrity); - } - if (this->hasFingerprint) - MS_DUMP(" has fingerprint"); - - MS_DUMP(""); - } -#endif - - StunPacket::Authentication StunPacket::CheckAuthentication( - const std::string& localUsername, const std::string& localPassword) - { - MS_TRACE(); - - switch (this->klass) - { - case Class::REQUEST: - case Class::INDICATION: - { - // Both USERNAME and MESSAGE-INTEGRITY must be present. - if (!this->messageIntegrity || this->username.empty()) - return Authentication::BAD_REQUEST; - - // Check that USERNAME attribute begins with our local username plus ":". - size_t localUsernameLen = localUsername.length(); - - if ( - this->username.length() <= localUsernameLen || this->username.at(localUsernameLen) != ':' || - (this->username.compare(0, localUsernameLen, localUsername) != 0)) - { + if (getMethod() == Method::ALLOCATE || getMethod() == Method::REFRESH || + getMethod() == Method::CREATEPERMISSION || getMethod() == Method::CHANNELBIND) { + // TURN认证:USERNAME应该等于ufrag + std::string username = getUsername(); + if (username != ufrag) { + TraceL << "TURN USERNAME validation failed, expected: " << ufrag << ", got: " << username; return Authentication::UNAUTHORIZED; } - - break; - } - // This method cannot check authentication in received responses (as we - // are ICE-Lite and don't generate requests). - case Class::SUCCESS_RESPONSE: - case Class::ERROR_RESPONSE: - { - MS_ERROR("cannot check authentication for a STUN response"); - - return Authentication::BAD_REQUEST; + } else { + // ICE认证:USERNAME格式为 local-ufrag:remote-ufrag(仅用于ICE BINDING请求) + std::string username = getUsername(); + if (!username.empty()) { + size_t localUsernameLen = ufrag.length(); + if (username.length() <= localUsernameLen || username.at(localUsernameLen) != ':' || + (username.compare(0, localUsernameLen, ufrag) != 0)) { + DebugL << "ICE USERNAME format validation failed, expected format: " << ufrag << ":remote-ufrag, got: " << username; + return Authentication::UNAUTHORIZED; + } + } } + break; } + // This method cannot check authentication in received responses (as we + // are ICE-Lite and don't generate requests). + case Class::INDICATION: return Authentication::OK; + case Class::SUCCESS_RESPONSE: + case Class::ERROR_RESPONSE: break; + } + if (attr_message_integrity) { // If there is FINGERPRINT it must be discarded for MESSAGE-INTEGRITY calculation, // so the header length field must be modified (and later restored). - if (this->hasFingerprint) + if (hasAttribute(StunAttribute::Type::FINGERPRINT)) { // Set the header length field: full size - header length (20) - FINGERPRINT length (8). - Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20 - 8)); - - // Calculate the HMAC-SHA1 of the message according to MESSAGE-INTEGRITY rules. - auto computedMessageIntegrity = openssl_HMACsha1( - localPassword.data(),localPassword.size(), this->data, (this->messageIntegrity - 4) - this->data); - - Authentication result; - - // Compare the computed HMAC-SHA1 with the MESSAGE-INTEGRITY in the packet. - if (std::memcmp(this->messageIntegrity, computedMessageIntegrity.data(), computedMessageIntegrity.size()) == 0) - result = Authentication::OK; - else - result = Authentication::UNAUTHORIZED; - - // Restore the header length field. - if (this->hasFingerprint) - Utils::Byte::Set2Bytes(this->data, 2, static_cast(this->size - 20)); - - return result; - } - - StunPacket* StunPacket::CreateSuccessResponse() - { - MS_TRACE(); - - MS_ASSERT( - this->klass == Class::REQUEST, - "attempt to create a success response for a non Request STUN packet"); - - return new StunPacket(Class::SUCCESS_RESPONSE, this->method, this->transactionId, nullptr, 0); - } - - StunPacket* StunPacket::CreateErrorResponse(uint16_t errorCode) - { - MS_TRACE(); - - MS_ASSERT( - this->klass == Class::REQUEST, - "attempt to create an error response for a non Request STUN packet"); - - auto* response = - new StunPacket(Class::ERROR_RESPONSE, this->method, this->transactionId, nullptr, 0); - - response->SetErrorCode(errorCode); - - return response; - } - - void StunPacket::Authenticate(const std::string& password) - { - // Just for Request, Indication and SuccessResponse messages. - if (this->klass == Class::ERROR_RESPONSE) - { - MS_ERROR("cannot set password for ErrorResponse messages"); - - return; + Byte::Set2Bytes((uint8_t *)_data->data(), 2, _data->size() - HEADER_SIZE - 8); } - this->password = password; - } + auto attr_realm = getAttribute(); + auto attr_nonce = getAttribute(); - void StunPacket::Serialize(uint8_t* buffer) - { - MS_TRACE(); + std::string key = password; + if (attr_nonce && attr_realm) { + // 使用长期凭证机制 + // 根据RFC 5389/5766标准:key = MD5(username ":" realm ":" password) + auto realm = attr_realm->getRealm(); + std::string input = ufrag + ":" + std::string(realm.data(), realm.size()) + ":" + password; + key = openssl_MD5(input.data(), input.size()); - // Some useful variables. - uint16_t usernamePaddedLen{ 0 }; - uint16_t xorMappedAddressPaddedLen{ 0 }; - bool addXorMappedAddress = - ((this->xorMappedAddress != nullptr) && this->method == StunPacket::Method::BINDING && - this->klass == Class::SUCCESS_RESPONSE); - bool addErrorCode = ((this->errorCode != 0u) && this->klass == Class::ERROR_RESPONSE); - bool addMessageIntegrity = (this->klass != Class::ERROR_RESPONSE && !this->password.empty()); - bool addFingerprint{ true }; // Do always. - - // Update data pointer. - this->data = buffer; - - // First calculate the total required size for the entire packet. - this->size = 20; // Header. - - if (!this->username.empty()) - { - usernamePaddedLen = Utils::Byte::PadTo4Bytes(static_cast(this->username.length())); - this->size += 4 + usernamePaddedLen; + // DebugL << "ufrag: " << ufrag; + // DebugL << "realm: " << realm.data(); + // DebugL << "password: " << password; + // DebugL << "input: " << input; } - if (this->priority != 0u) - this->size += 4 + 4; + auto computedMessageIntegrity = openssl_HMACsha1(key.data(), key.size(), _data->data(), _message_integrity_data_len); - if (this->iceControlling != 0u) - this->size += 4 + 8; + // DebugL << "cal MessageIntegrity"; + // DebugL << "password: " << password; + // DebugL << "key: " << toolkit::hexdump(key.data(), key.size()); + // DebugL << "data: " << toolkit::hexdump(_data->data(), _message_integrity_data_len); + // DebugL << "_message_integrity_data_len: " << _message_integrity_data_len; + // DebugL << "_hmac: " << toolkit::hexdump(attr_message_integrity->_hmac.data(), attr_message_integrity->_hmac.size()); + // DebugL << "cal: " << toolkit::hexdump(computedMessageIntegrity.data(), computedMessageIntegrity.size()); - if (this->iceControlled != 0u) - this->size += 4 + 8; + if (attr_message_integrity->getHmac() != computedMessageIntegrity) { + return Authentication::UNAUTHORIZED; + } - if (this->hasUseCandidate) - this->size += 4; + if (hasAttribute(StunAttribute::Type::FINGERPRINT)) { + Byte::Set2Bytes((uint8_t*)_data->data(), 2, _data->size() - HEADER_SIZE); + } + } - if (addXorMappedAddress) - { - switch (this->xorMappedAddress->sa_family) - { - case AF_INET: - { - xorMappedAddressPaddedLen = 8; - this->size += 4 + 8; - - break; - } - - case AF_INET6: - { - xorMappedAddressPaddedLen = 20; - this->size += 4 + 20; - - break; - } - - default: - { - MS_ERROR("invalid inet family in XOR-MAPPED-ADDRESS attribute"); - - addXorMappedAddress = false; - } + // FINGERPRINT验证 + if (hasAttribute(StunAttribute::Type::FINGERPRINT)) { + auto attr_fingerprint = getAttribute(); + if (attr_fingerprint) { + // 计算FINGERPRINT:对除FINGERPRINT属性外的整个包计算CRC32 + uint32_t computedFingerprint = getCRC32((uint8_t*)_data->data(), _data->size() - 8) ^ 0x5354554e; + if (attr_fingerprint->getFingerprint() != computedFingerprint) { + // DebugL << "FINGERPRINT verification failed, expected: " << std::hex << computedFingerprint + // << ", got: " << attr_fingerprint->getFingerprint(); + return Authentication::UNAUTHORIZED; + } else { + // TraceL << "FINGERPRINT verification passed"; } } + } - if (addErrorCode) - this->size += 4 + 4; + return Authentication::OK; +} - if (addMessageIntegrity) - this->size += 4 + 20; +void StunPacket::serialize() { + //TraceL; - if (addFingerprint) - this->size += 4 + 4; + _data = BufferRaw::create(); + for (auto it : _attribute_map) { + it.second->storeToData(); + } - // Merge class and method fields into type. - uint16_t typeField = (static_cast(this->method) & 0x0f80) << 2; + auto attr_size = getAttrSize(); - typeField |= (static_cast(this->method) & 0x0070) << 1; - typeField |= (static_cast(this->method) & 0x000f); - typeField |= (static_cast(this->klass) & 0x02) << 7; - typeField |= (static_cast(this->klass) & 0x01) << 4; + if (getClass() == StunPacket::Class::ERROR_RESPONSE) { + setNeedFingerprint(false); + setNeedMessageIntegrity(false); + } - // Set type field. - Utils::Byte::Set2Bytes(buffer, 0, typeField); - // Set length field. - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size) - 20); - // Set magic cookie. - std::memcpy(buffer + 4, StunPacket::magicCookie, 4); - // Set TransactionId field. - std::memcpy(buffer + 8, this->transactionId, 12); - // Update the transaction ID pointer. - this->transactionId = buffer + 8; - // Add atributes. - size_t pos{ 20 }; + if (getClass() == StunPacket::Class::INDICATION) { + setNeedMessageIntegrity(false); + } - // Add USERNAME. - if (usernamePaddedLen != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USERNAME)); - Utils::Byte::Set2Bytes(buffer, pos + 2, static_cast(this->username.length())); - std::memcpy(buffer + pos + 4, this->username.c_str(), this->username.length()); - pos += 4 + usernamePaddedLen; - } + auto message_integrity_size = getNeedMessageIntegrity() ? 24 : 0; + auto fingerprint_size = getNeedFingerprint() ? 8 : 0; - // Add PRIORITY. - if (this->priority != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::PRIORITY)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - Utils::Byte::Set4Bytes(buffer, pos + 4, this->priority); - pos += 4 + 4; - } + auto packet_size = HEADER_SIZE + attr_size + message_integrity_size + fingerprint_size; + _data->setCapacity(packet_size); + _data->setSize(packet_size); - // Add ICE-CONTROLLING. - if (this->iceControlling != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLING)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 8); - Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlling); - pos += 4 + 8; - } + // Merge class and method fields into type. + uint16_t typeField = (static_cast(_method) & 0x0f80) << 2; - // Add ICE-CONTROLLED. - if (this->iceControlled != 0u) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ICE_CONTROLLED)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 8); - Utils::Byte::Set8Bytes(buffer, pos + 4, this->iceControlled); - pos += 4 + 8; - } + typeField |= (static_cast(_method) & 0x0070) << 1; + typeField |= (static_cast(_method) & 0x000f); + typeField |= (static_cast(_klass) & 0x02) << 7; + typeField |= (static_cast(_klass) & 0x01) << 4; - // Add USE-CANDIDATE. - if (this->hasUseCandidate) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::USE_CANDIDATE)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 0); - pos += 4; - } + // Set type field. + Byte::Set2Bytes((unsigned char *)_data->data(), 0, typeField); + uint16_t initial_length = static_cast(attr_size + message_integrity_size); + Byte::Set2Bytes((unsigned char *)_data->data(), 2, initial_length); + // Set magic cookie. + std::memcpy(_data->data() + 4, StunPacket::_magicCookie, 4); + // Set TransactionId field. + std::memcpy(_data->data() + 8, _transaction_id.data(), 12); - // Add XOR-MAPPED-ADDRESS - if (addXorMappedAddress) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::XOR_MAPPED_ADDRESS)); - Utils::Byte::Set2Bytes(buffer, pos + 2, xorMappedAddressPaddedLen); - - uint8_t* attrValue = buffer + pos + 4; - - switch (this->xorMappedAddress->sa_family) - { - case AF_INET: - { - // Set first byte to 0. - attrValue[0] = 0; - // Set inet family. - attrValue[1] = 0x01; - // Set port and XOR it. - std::memcpy( - attrValue + 2, - &(reinterpret_cast(this->xorMappedAddress))->sin_port, - 2); - attrValue[2] ^= StunPacket::magicCookie[0]; - attrValue[3] ^= StunPacket::magicCookie[1]; - // Set address and XOR it. - std::memcpy( - attrValue + 4, - &(reinterpret_cast(this->xorMappedAddress))->sin_addr.s_addr, - 4); - attrValue[4] ^= StunPacket::magicCookie[0]; - attrValue[5] ^= StunPacket::magicCookie[1]; - attrValue[6] ^= StunPacket::magicCookie[2]; - attrValue[7] ^= StunPacket::magicCookie[3]; - - pos += 4 + 8; - - break; - } - - case AF_INET6: - { - // Set first byte to 0. - attrValue[0] = 0; - // Set inet family. - attrValue[1] = 0x02; - // Set port and XOR it. - std::memcpy( - attrValue + 2, - &(reinterpret_cast(this->xorMappedAddress))->sin6_port, - 2); - attrValue[2] ^= StunPacket::magicCookie[0]; - attrValue[3] ^= StunPacket::magicCookie[1]; - // Set address and XOR it. - std::memcpy( - attrValue + 4, - &(reinterpret_cast(this->xorMappedAddress))->sin6_addr.s6_addr, - 16); - attrValue[4] ^= StunPacket::magicCookie[0]; - attrValue[5] ^= StunPacket::magicCookie[1]; - attrValue[6] ^= StunPacket::magicCookie[2]; - attrValue[7] ^= StunPacket::magicCookie[3]; - attrValue[8] ^= this->transactionId[0]; - attrValue[9] ^= this->transactionId[1]; - attrValue[10] ^= this->transactionId[2]; - attrValue[11] ^= this->transactionId[3]; - attrValue[12] ^= this->transactionId[4]; - attrValue[13] ^= this->transactionId[5]; - attrValue[14] ^= this->transactionId[6]; - attrValue[15] ^= this->transactionId[7]; - attrValue[16] ^= this->transactionId[8]; - attrValue[17] ^= this->transactionId[9]; - attrValue[18] ^= this->transactionId[10]; - attrValue[19] ^= this->transactionId[11]; - - pos += 4 + 20; - - break; - } - } - } - - // Add ERROR-CODE. - if (addErrorCode) - { - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::ERROR_CODE)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - - auto codeClass = static_cast(this->errorCode / 100); - uint8_t codeNumber = static_cast(this->errorCode) - (codeClass * 100); - - Utils::Byte::Set2Bytes(buffer, pos + 4, 0); - Utils::Byte::Set1Byte(buffer, pos + 6, codeClass); - Utils::Byte::Set1Byte(buffer, pos + 7, codeNumber); - pos += 4 + 4; + storeAttrMessage(); + if (message_integrity_size) { + auto ufrag = _peer_ufrag; + auto password = _peer_password; + if (getClass() == StunPacket::Class::SUCCESS_RESPONSE || + getClass() == StunPacket::Class::ERROR_RESPONSE) { + ufrag = _ufrag; + password = _password; } // Add MESSAGE-INTEGRITY. - if (addMessageIntegrity) - { - // Ignore FINGERPRINT. - if (addFingerprint) - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20 - 8)); + auto attr_nonce = getAttribute(); + auto attr_realm = getAttribute(); + // FIXME: need use SASLprep(password) replace password + // 根据RFC 5766标准:key = MD5(username ":" realm ":" SASLprep(password)) + std::string key = password; + if (attr_nonce && attr_realm) { + // 使用长期凭证机制 + // key = MD5(username ":" realm ":" password) + auto realm = attr_realm->getRealm(); + std::string username = ufrag; // 对于response消息,使用ufrag作为username + std::string input = username + ":" + std::string(realm.data(), realm.size()) + ":" + password; + key = openssl_MD5(input.data(), input.size()); - // Calculate the HMAC-SHA1 of the packet according to MESSAGE-INTEGRITY rules. - auto computedMessageIntegrity = openssl_HMACsha1(this->password.data(), this->password.size(), buffer, pos); - - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::MESSAGE_INTEGRITY)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 20); - std::memcpy(buffer + pos + 4, computedMessageIntegrity.data(), computedMessageIntegrity.size()); - - // Update the pointer. - this->messageIntegrity = buffer + pos + 4; - pos += 4 + 20; - - // Restore length field. - if (addFingerprint) - Utils::Byte::Set2Bytes(buffer, 2, static_cast(this->size - 20)); - } - else - { - // Unset the pointer (if it was set). - this->messageIntegrity = nullptr; + // DebugL << "Long-term credential used for response:"; + // DebugL << "ufrag: " << ufrag; + // DebugL << "realm: " << std::string(realm.data(), realm.size()); + // DebugL << "password: " << password; + // DebugL << "input: " << input; + // DebugL << "MD5 key: " << toolkit::hexdump(key.data(), key.size()); } - // Add FINGERPRINT. - if (addFingerprint) - { - // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT - // attribute and XOR it with 0x5354554e. - uint32_t computedFingerprint = GetCRC32(buffer, pos) ^ 0x5354554e; + size_t mi_calc_len = HEADER_SIZE + attr_size; + auto computedMessageIntegrity = openssl_HMACsha1(key.data(), key.size(), _data->data(), mi_calc_len); + auto attr_message_integrity = std::make_shared(); + attr_message_integrity->setHmac(computedMessageIntegrity); + attr_message_integrity->storeToData(); + memcpy((unsigned char *)_data->data() + HEADER_SIZE + attr_size, attr_message_integrity->data(), attr_message_integrity->size()); - Utils::Byte::Set2Bytes(buffer, pos, static_cast(Attribute::FINGERPRINT)); - Utils::Byte::Set2Bytes(buffer, pos + 2, 4); - Utils::Byte::Set4Bytes(buffer, pos + 4, computedFingerprint); - pos += 4 + 4; - - // Set flag. - this->hasFingerprint = true; - } - else - { - this->hasFingerprint = false; - } - - MS_ASSERT(pos == this->size, "pos != this->size"); + // DebugL << "Serialize MESSAGE-INTEGRITY:"; + // DebugL << "password: \"" << password << "\""; + // DebugL << "key: " << toolkit::hexdump(key.data(), key.size()); + // DebugL << "hmac_calculated: " << toolkit::hexdump(computedMessageIntegrity.data(), computedMessageIntegrity.size()); } + + if (fingerprint_size) { + // Add FINGERPRINT. + // Compute the CRC32 of the packet up to (but excluding) the FINGERPRINT + uint16_t final_length = static_cast(attr_size + message_integrity_size + fingerprint_size); + Byte::Set2Bytes((unsigned char *)_data->data(), 2, final_length); + size_t fp_calc_len = HEADER_SIZE + attr_size + message_integrity_size; + uint32_t computedFingerprint = getCRC32((unsigned char *)_data->data(), fp_calc_len) ^ 0x5354554e; + + auto attr_fingerprint = std::make_shared(); + attr_fingerprint->setFingerprint(computedFingerprint); + attr_fingerprint->storeToData(); + memcpy((unsigned char *)_data->data() + HEADER_SIZE + attr_size + message_integrity_size, attr_fingerprint->data(), attr_fingerprint->size()); + } +} + +StunPacket::Ptr StunPacket::createSuccessResponse() const { + // TraceL; + CHECK(_klass == Class::REQUEST, "attempt to create a success response for a non Request STUN packet"); + + auto packet = std::make_shared(Class::SUCCESS_RESPONSE, _method, _transaction_id.c_str()); + + // 复制认证相关属性到响应包中,用于MESSAGE-INTEGRITY计算 + auto attr_realm = getAttribute(StunAttribute::Type::REALM); + if (attr_realm) { + packet->addAttribute(attr_realm); + } + + auto attr_nonce = getAttribute(StunAttribute::Type::NONCE); + if (attr_nonce) { + packet->addAttribute(attr_nonce); + DebugL << "Copied NONCE attribute to response"; + } + + return packet; +} + +StunPacket::Ptr StunPacket::createErrorResponse(StunAttrErrorCode::Code errorCode) const { + TraceL; + CHECK(_klass == Class::REQUEST, "attempt to create an error response for a non Request STUN packet"); + auto ret = std::make_shared(Class::ERROR_RESPONSE, _method, _transaction_id.c_str()); + auto attr = std::make_shared(); + attr->setErrorCode(errorCode); + ret->addAttribute(std::move(attr)); + return ret; +} + +char *StunPacket::data() const { + return _data ? _data->data() : nullptr; +} + +size_t StunPacket::size() const { + return _data ? _data->size() : 0; +} + +bool StunPacket::loadFromData(const uint8_t *buf, size_t len) { + if (HEADER_SIZE > len) { + WarnL << "size too small " << len; + return false; + } + + _data = BufferRaw::create(); + _data->assign((const char *)(buf), len); + + _transaction_id.assign((const char *)buf + 8, 12); + + if (len == HEADER_SIZE) { + return true; + } + + return loadAttrMessage(buf + HEADER_SIZE, len - HEADER_SIZE); +} + +bool StunPacket::loadAttrMessage(const uint8_t *buf, size_t len) { + _attribute_map.clear(); + _message_integrity_data_len = HEADER_SIZE + len; + + uint8_t *ptr = const_cast(buf); + StunAttribute::Ptr attr = nullptr; + while (ptr < buf + len) { + auto type = (StunAttribute::Type)Byte::Get2Bytes(ptr, 0); + size_t length = Byte::Get2Bytes(ptr, 2); + size_t lengthAlign = Byte::PadTo4Bytes((uint16_t)length); + + switch (type) { + case StunAttribute::Type::MAPPED_ADDRESS: attr = std::make_shared(); break; + case StunAttribute::Type::USERNAME: attr = std::make_shared(); break; + case StunAttribute::Type::MESSAGE_INTEGRITY: + attr = std::make_shared(); + _message_integrity_data_len = HEADER_SIZE + ptr - buf; + break; + case StunAttribute::Type::ERROR_CODE: attr = std::make_shared(); break; + case StunAttribute::Type::CHANNEL_NUMBER: attr = std::make_shared(); break; + case StunAttribute::Type::LIFETIME: attr = std::make_shared(); break; + case StunAttribute::Type::DATA: attr = std::make_shared(); break; + case StunAttribute::Type::REALM: attr = std::make_shared(); break; + case StunAttribute::Type::NONCE: attr = std::make_shared(); break; + case StunAttribute::Type::REQUESTED_TRANSPORT: attr = std::make_shared(); break; + case StunAttribute::Type::XOR_PEER_ADDRESS: attr = std::make_shared(_transaction_id); break; + case StunAttribute::Type::XOR_RELAYED_ADDRESS: attr = std::make_shared(_transaction_id); break; + case StunAttribute::Type::XOR_MAPPED_ADDRESS: attr = std::make_shared(_transaction_id); break; + + case StunAttribute::Type::PRIORITY: attr = std::make_shared(); break; + case StunAttribute::Type::USE_CANDIDATE: attr = std::make_shared(); break; + case StunAttribute::Type::FINGERPRINT: attr = std::make_shared(); break; + case StunAttribute::Type::ICE_CONTROLLED: attr = std::make_shared(); break; + case StunAttribute::Type::ICE_CONTROLLING: attr = std::make_shared(); break; + case StunAttribute::Type::GOOG_NETWORK_INFO: + case StunAttribute::Type::SOFTWARE: + break; + default: WarnL << "not support Attribute " << (uint16_t)type << "," << toolkit::hexdump(ptr, 2); break; + } + + if (attr) { + if (ptr + lengthAlign + 4 > buf + len) { + WarnL << "the attribute length exceeds the remaining size, packet discarded"; + return false; + } + + if (attr->loadFromData(ptr, StunAttribute::ATTR_HEADER_SIZE + length)) { + _attribute_map.emplace(type, std::move(attr)); + + } else { + if (StunAttribute::isComprehensionRequired(ptr, 4)) { + WarnL << "parse a Comprehension Required Stun Attribute failed, type=" << (uint16_t)type << " len=" << length; + return false; + } + WarnL << "parse Stun Attribute failed type=" << (uint16_t)type << " len=" << length; + } + attr = nullptr; + } + + ptr += lengthAlign + StunAttribute::ATTR_HEADER_SIZE; + } + return true; +} + +bool StunPacket::storeAttrMessage() { + uint8_t *buf = (uint8_t *)_data->data() + HEADER_SIZE; + for (auto &pr : _attribute_map) { + memcpy(buf, pr.second->data(), pr.second->size()); + buf += pr.second->size(); + } + return true; +} + +size_t StunPacket::getAttrSize() const { + size_t size = 0; + for (auto &pr : _attribute_map) { + size += pr.second->size(); + } + return size; +} + +SuccessResponsePacket::SuccessResponsePacket(Method method, const std::string& transaction_id) : + StunPacket(Class::SUCCESS_RESPONSE, method, transaction_id.c_str()) { +} + +ErrorResponsePacket::ErrorResponsePacket(Method method, const std::string& transaction_id, StunAttrErrorCode::Code error_code) : + StunPacket(Class::ERROR_RESPONSE, method, transaction_id.c_str()) { + DebugL; + auto attr = std::make_shared(); + attr->setErrorCode(error_code); + addAttribute(std::move(attr)); +} + } // namespace RTC diff --git a/webrtc/StunPacket.hpp b/webrtc/StunPacket.hpp index 2776a9b6..294ee4a5 100644 --- a/webrtc/StunPacket.hpp +++ b/webrtc/StunPacket.hpp @@ -1,213 +1,689 @@ -/** -ISC License +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. +*/ -Copyright © 2015, Iñaki Baz Castillo +#ifndef ZLMEDIAKIT_WEBRTC_STUN_PACKET_HPP +#define ZLMEDIAKIT_WEBRTC_STUN_PACKET_HPP -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef MS_RTC_STUN_PACKET_HPP -#define MS_RTC_STUN_PACKET_HPP - - -#include "logger.h" -#include "Utils.hpp" #include +#include "Util/Byte.hpp" +#include "Network/Buffer.h" +#include "Network/sockutil.h" -namespace RTC -{ - class StunPacket - { - public: - // STUN message class. - enum class Class : uint16_t - { - REQUEST = 0, - INDICATION = 1, - SUCCESS_RESPONSE = 2, - ERROR_RESPONSE = 3 - }; +namespace RTC { +// reference https://rcf-editor.org/rfc/rfc8489 +// reference https://rcf-editor.org/rfc/rfc8656 +// reference https://rcf-editor.org/rfc/rfc8445 - // STUN message method. - enum class Method : uint16_t - { - BINDING = 1 - }; +//////////// Attribute ////////////////////////// +// reference https://rcf-editor.org/rfc/rfc8489 +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Type | Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Value (variable) .... ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 4: Format of STUN Attributes + reference https://www.rfc-editor.org/rfc/rfc8489.html#section-14 +*/ +class StunAttribute { +public: + // Attribute type. + enum class Type : uint16_t { + MAPPED_ADDRESS = 0x0001, + RESPONSE_ADDRESS = 0x0002, // Reserved; was RESPONSE-ADDRESS prior to [RFC5389] + CHANGE_REQUEST = 0x0003, // Reserved; was CHANGE-REQUEST prior to [RFC5389] + CHANGED_ADDRESS = 0x0005, // Reserved; was CHANGED-ADDRESS prior to [RFC5389] + USERNAME = 0x0006, + PASSWORD = 0x0005, // Reserved; was PASSWORD prior to [RFC5389] + MESSAGE_INTEGRITY = 0x0008, + ERROR_CODE = 0x0009, + UNKNOWN_ATTRIBUTES = 0x000A, + REFLECTED_FROM = 0x000B, // Reserved; was REFLECTED-FROM prior to [RFC5389] + CHANNEL_NUMBER = 0x000C, // [RFC5766] + LIFETIME = 0x000D, // [RFC5766] + BANDWIDTH = 0x0010, // Reserved; [RFC5766] + XOR_PEER_ADDRESS = 0x0012, // [RFC5766] + DATA = 0x0013, // [RFC5766] + REALM = 0x0014, + NONCE = 0x0015, + XOR_RELAYED_ADDRESS = 0x0016, // [RFC5766] + EVEN_PORT = 0x0018, // [RFC5766] + REQUESTED_TRANSPORT = 0x0019, // [RFC5766] + DONT_FRAGMENT = 0x001A, // [RFC5766] + MESSAGE_INTEGRITY_SHA256 = 0x001C, + USERHASH = 0x001E, + PASSWORD_ALGORITHM = 0x001D, + XOR_MAPPED_ADDRESS = 0x0020, + TIMER_VAL = 0x0021, // Reserved; [RFC5766] + RESERVATION_TOKEN = 0x0022, // [RFC5766] + PRIORITY = 0x0024, + USE_CANDIDATE = 0x0025, - // Attribute type. - enum class Attribute : uint16_t - { - MAPPED_ADDRESS = 0x0001, - USERNAME = 0x0006, - MESSAGE_INTEGRITY = 0x0008, - ERROR_CODE = 0x0009, - UNKNOWN_ATTRIBUTES = 0x000A, - REALM = 0x0014, - NONCE = 0x0015, - XOR_MAPPED_ADDRESS = 0x0020, - PRIORITY = 0x0024, - USE_CANDIDATE = 0x0025, - SOFTWARE = 0x8022, - ALTERNATE_SERVER = 0x8023, - FINGERPRINT = 0x8028, - ICE_CONTROLLED = 0x8029, - ICE_CONTROLLING = 0x802A - }; - - // Authentication result. - enum class Authentication - { - OK = 0, - UNAUTHORIZED = 1, - BAD_REQUEST = 2 - }; - - public: - static bool IsStun(const uint8_t* data, size_t len) - { - // clang-format off - return ( - // STUN headers are 20 bytes. - (len >= 20) && - // DOC: https://tools.ietf.org/html/draft-ietf-avtcore-rfc5764-mux-fixes - (data[0] < 3) && - // Magic cookie must match. - (data[4] == StunPacket::magicCookie[0]) && (data[5] == StunPacket::magicCookie[1]) && - (data[6] == StunPacket::magicCookie[2]) && (data[7] == StunPacket::magicCookie[3]) - ); - // clang-format on - } - static StunPacket* Parse(const uint8_t* data, size_t len); - - private: - static const uint8_t magicCookie[]; - - public: - StunPacket( - Class klass, Method method, const uint8_t* transactionId, const uint8_t* data, size_t size); - ~StunPacket(); - - void Dump() const; - Class GetClass() const - { - return this->klass; - } - Method GetMethod() const - { - return this->method; - } - const uint8_t* GetData() const - { - return this->data; - } - size_t GetSize() const - { - return this->size; - } - void SetUsername(const char* username, size_t len) - { - this->username.assign(username, len); - } - void SetPriority(uint32_t priority) - { - this->priority = priority; - } - void SetIceControlling(uint64_t iceControlling) - { - this->iceControlling = iceControlling; - } - void SetIceControlled(uint64_t iceControlled) - { - this->iceControlled = iceControlled; - } - void SetUseCandidate() - { - this->hasUseCandidate = true; - } - void SetXorMappedAddress(const struct sockaddr* xorMappedAddress) - { - this->xorMappedAddress = xorMappedAddress; - } - void SetErrorCode(uint16_t errorCode) - { - this->errorCode = errorCode; - } - void SetMessageIntegrity(const uint8_t* messageIntegrity) - { - this->messageIntegrity = messageIntegrity; - } - void SetFingerprint() - { - this->hasFingerprint = true; - } - const std::string& GetUsername() const - { - return this->username; - } - uint32_t GetPriority() const - { - return this->priority; - } - uint64_t GetIceControlling() const - { - return this->iceControlling; - } - uint64_t GetIceControlled() const - { - return this->iceControlled; - } - bool HasUseCandidate() const - { - return this->hasUseCandidate; - } - uint16_t GetErrorCode() const - { - return this->errorCode; - } - bool HasMessageIntegrity() const - { - return (this->messageIntegrity ? true : false); - } - bool HasFingerprint() const - { - return this->hasFingerprint; - } - Authentication CheckAuthentication( - const std::string& localUsername, const std::string& localPassword); - StunPacket* CreateSuccessResponse(); - StunPacket* CreateErrorResponse(uint16_t errorCode); - void Authenticate(const std::string& password); - void Serialize(uint8_t* buffer); - - private: - // Passed by argument. - Class klass; // 2 bytes. - Method method; // 2 bytes. - const uint8_t* transactionId{ nullptr }; // 12 bytes. - uint8_t* data{ nullptr }; // Pointer to binary data. - size_t size{ 0u }; // The full message size (including header). - // STUN attributes. - std::string username; // Less than 513 bytes. - uint32_t priority{ 0u }; // 4 bytes unsigned integer. - uint64_t iceControlling{ 0u }; // 8 bytes unsigned integer. - uint64_t iceControlled{ 0u }; // 8 bytes unsigned integer. - bool hasUseCandidate{ false }; // 0 bytes. - const uint8_t* messageIntegrity{ nullptr }; // 20 bytes. - bool hasFingerprint{ false }; // 4 bytes. - const struct sockaddr* xorMappedAddress{ nullptr }; // 8 or 20 bytes. - uint16_t errorCode{ 0u }; // 4 bytes (no reason phrase). - std::string password; + //Comprehension-optional range (0x8000-0xFFFF) + PASSWORD_ALGORITHMS = 0x8002, + ALTERNATE_DOMAIN = 0x8003, + SOFTWARE = 0x8022, + ALTERNATE_SERVER = 0x8023, + FINGERPRINT = 0x8028, + ICE_CONTROLLED = 0x8029, + ICE_CONTROLLING = 0x802A, + GOOG_NETWORK_INFO = 0xC057, }; + + static const size_t ATTR_HEADER_SIZE = 4; + static bool isComprehensionRequired(const uint8_t *data, size_t len); + + using Ptr = std::shared_ptr; + StunAttribute(StunAttribute::Type type) : _type(type) {} + virtual ~StunAttribute() = default; + + char *data() { return _data ? _data->data() : nullptr; } + char *body() { return _data ? _data->data() + ATTR_HEADER_SIZE : nullptr; } + size_t size() const { return _data ? _data->size() : 0; } + + Type type() const { return _type; } + + virtual bool loadFromData(const uint8_t *buf, size_t len) = 0; + virtual bool storeToData() = 0; + // virtual std::string dump() = 0; + +protected: + const uint8_t * loadHeader(const uint8_t *buf); + uint8_t * storeHeader(); + +protected: + Type _type; + uint16_t _length; + toolkit::BufferRaw::Ptr _data; +}; + +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0 0 0 0 0 0 0 0| Family | Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Address (32 bits or 128 bits) | +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 5: Format of MAPPED-ADDRESS Attribute + reference https://www.rfc-editor.org/rfc/rfc8489.html#page-37 +*/ +class StunAttrMappedAddress : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::MAPPED_ADDRESS; + StunAttrMappedAddress() : StunAttribute(TYPE) {}; + virtual ~StunAttrMappedAddress() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; +}; + +class StunAttrUserName : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::USERNAME; + StunAttrUserName() : StunAttribute(TYPE) {}; + virtual ~StunAttrUserName() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setUsername(std::string username) { _username = std::move(username); } + + const std::string& getUsername() const { return _username; } + +private: + std::string _username; +}; + +class StunAttrMessageIntegrity : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::MESSAGE_INTEGRITY; + StunAttrMessageIntegrity() : StunAttribute(TYPE) {}; + virtual ~StunAttrMessageIntegrity() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setHmac(std::string hmac) { _hmac = std::move(hmac); } + const std::string &getHmac() const { return _hmac; } +private: + std::string _hmac; +}; + +class StunAttrErrorCode : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::ERROR_CODE; + StunAttrErrorCode() : StunAttribute(TYPE) {}; + virtual ~StunAttrErrorCode() = default; + + enum class Code : uint16_t { + Invalid = 0, // + TryAlternate = 300, //尝试备用服务器 + BadRequest = 400, + Unauthorized = 401, + Forbidden = 403, //禁止 + RequestTimedOut = 408, //请求超时(客户端认为此事务已经失败) + UnknownAttribute = 420, + AllocationMismatch = 438, + StaleNonce = 438, //NONCE 不再有效,客户端应使用响应中的NONCE重试 + AddressFamilyNotSupported = 440, //不支持的协议簇 + WrongCredentials = 441, //凭据错误 + UnsupportedTransportAddress = 442, //不支持的传输地址 + AllocationQuotaReached = 486, //alloction 达到上限,客户端应该至少等待一分钟后重新尝试创建 + RoleConflict = 487, //角色冲突 + ServerError = 500, //服务器临时错误,客户端应重试 + InsuficientCapacity = 508, //容量不足,没有更多可用的中继传输地址 + }; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setErrorCode(Code error_code) { _error_code = error_code; } + Code getErrorCode() const { return _error_code; } +private: + Code _error_code; +}; + +class StunAttrChannelNumber : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::CHANNEL_NUMBER; + StunAttrChannelNumber() : StunAttribute(TYPE) {}; + virtual ~StunAttrChannelNumber() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + + void setChannelNumber(uint16_t channel_number) { _channel_number = channel_number; } + uint16_t getChannelNumber() const { return _channel_number; } +private: + uint16_t _channel_number; +}; + +class StunAttrLifeTime : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::LIFETIME; + StunAttrLifeTime() : StunAttribute(TYPE) {}; + ~StunAttrLifeTime() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setLifetime(uint32_t lifetime) { _lifetime = lifetime; } + uint32_t getLifetime() const { return _lifetime; } +private: + uint32_t _lifetime; +}; + +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0 0 0 0 0 0 0 0| Family | X-Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| X-Address (Variable) ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 6: Format of XOR-MAPPED-ADDRESS Attribute + reference https://www.rfc-editor.org/rfc/rfc8489.html#page-38 +*/ +class StunAttrXorPeerAddress : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::XOR_PEER_ADDRESS; + StunAttrXorPeerAddress(std::string transaction_id) + : StunAttribute(TYPE) + , _transaction_id(std::move(transaction_id)) {} + virtual ~StunAttrXorPeerAddress() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setAddr(const struct sockaddr_storage &addr) { _addr = addr; } + const struct sockaddr_storage& getAddr() const { return _addr; } + + std::string getIp() const { return toolkit::SockUtil::inet_ntoa((struct sockaddr *)&_addr); } + uint16_t getPort() const { return toolkit::SockUtil::inet_port((struct sockaddr *)&_addr); } + +protected: + struct sockaddr_storage _addr; + std::string _transaction_id; +}; + +class StunAttrData : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::DATA; + StunAttrData() : StunAttribute(TYPE) {}; + virtual ~StunAttrData() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + + void setData(std::string data) { _data_content = std::move(data); } + void setData(const char *data, int size) { _data_content.assign(data, size); } + const std::string &getData() const { return _data_content; } + +private: + std::string _data_content; +}; + +class StunAttrRealm : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::REALM; + StunAttrRealm() : StunAttribute(TYPE) {}; + virtual ~StunAttrRealm() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setRealm(std::string realm) { _realm = std::move(realm); } + const std::string &getRealm() const { return _realm; } +private: + // 长度小于128字符 + std::string _realm; +}; + +class StunAttrNonce : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::NONCE; + StunAttrNonce() : StunAttribute(TYPE) {}; + virtual ~StunAttrNonce() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setNonce(std::string nonce) { _nonce = std::move(nonce); } + const std::string& getNonce() const { return _nonce; } +private: + // 长度小于128字符 + std::string _nonce; +}; + +class StunAttrXorRelayedAddress : public StunAttrXorPeerAddress { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::XOR_RELAYED_ADDRESS; + StunAttrXorRelayedAddress(std::string transaction_id) : StunAttrXorPeerAddress(std::move(transaction_id)) { + _type = TYPE; + } + virtual ~StunAttrXorRelayedAddress() = default; +}; + +class StunAttrXorMappedAddress : public StunAttrXorPeerAddress { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::XOR_MAPPED_ADDRESS; + StunAttrXorMappedAddress(std::string transaction_id) : StunAttrXorPeerAddress(std::move(transaction_id)) { + _type = TYPE; + } + virtual ~StunAttrXorMappedAddress() = default; +}; + +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Protocol | RFFU | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + reference https://www.rfc-editor.org/rfc/rfc5766.html#section-14.7 +*/ +class StunAttrRequestedTransport : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::REQUESTED_TRANSPORT; + StunAttrRequestedTransport() : StunAttribute(TYPE) {}; + virtual ~StunAttrRequestedTransport() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + enum class Protocol : uint8_t { + // This specification only allows the use of codepoint 17 (User Datagram Protocol). + UDP = 0x11, + }; + + void setProtocol(Protocol protocol) { _protocol = protocol; } + Protocol getProtocol() const { return _protocol; } +private: + Protocol _protocol = Protocol::UDP; +}; + +class StunAttrPriority : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::PRIORITY; + StunAttrPriority() : StunAttribute(TYPE) {}; + virtual ~StunAttrPriority() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setPriority(uint64_t priority) { _priority = priority; } + uint64_t getPriority() const { return _priority; } +private: + uint32_t _priority; +}; + +class StunAttrUseCandidate : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::USE_CANDIDATE; + StunAttrUseCandidate() : StunAttribute(TYPE) {}; + virtual ~StunAttrUseCandidate() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; +}; + +class StunAttrFingerprint : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::FINGERPRINT; + StunAttrFingerprint() : StunAttribute(TYPE) {}; + virtual ~StunAttrFingerprint() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setFingerprint(uint32_t fingerprint) { _fingerprint = fingerprint; } + uint32_t getFingerprint() const { return _fingerprint; } +private: + uint32_t _fingerprint; +}; + +class StunAttrIceControlled : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::ICE_CONTROLLED; + StunAttrIceControlled() : StunAttribute(TYPE) {}; + virtual ~StunAttrIceControlled() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setTiebreaker(uint64_t tiebreaker) { _tiebreaker = tiebreaker; } + uint64_t getTiebreaker() const { return _tiebreaker; } +private: + uint64_t _tiebreaker = 0; // 8 bytes unsigned integer. +}; + +class StunAttrIceControlling : public StunAttribute { +public: + using Ptr = std::shared_ptr; + static constexpr Type TYPE = StunAttribute::Type::ICE_CONTROLLING; + StunAttrIceControlling() : StunAttribute(TYPE) {}; + virtual ~StunAttrIceControlling() = default; + + bool loadFromData(const uint8_t *buf, size_t len) override; + bool storeToData() override; + // std::string dump() override; + + void setTiebreaker(uint64_t tiebreaker) { _tiebreaker = tiebreaker; } + uint64_t getTiebreaker() const { return _tiebreaker; } +private: + uint64_t _tiebreaker = 0; // 8 bytes unsigned integer. +}; + +//////////// STUN ////////////////////////// +/* +0 1 2 3 +0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +|0 0| STUN Message Type | Message Length | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Magic Cookie | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| | +| Transaction ID (96 bits) | +| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Figure 2: Format of STUN Message Header + reference https://www.rfc-editor.org/rfc/rfc8489.html#section-5 */ +class StunPacket : public toolkit::Buffer { +public: + using Ptr = std::shared_ptr; + + // STUN message class. + enum class Class : uint8_t { + REQUEST = 0, + INDICATION = 1, + SUCCESS_RESPONSE = 2, + ERROR_RESPONSE = 3 + }; + + // STUN message method. + enum class Method : uint16_t { + BINDING = 0x001, + + //TURN Extended + //https://www.rfc-editor.org/rfc/rfc5766.html#section-13 + ALLOCATE = 0x003, // (only request/response semantics defined) + REFRESH = 0x004, // (only request/response semantics defined) + SEND = 0x006, // (only indication semantics defined) + DATA = 0x007, // (only indication semantics defined) + CREATEPERMISSION = 0x008, // (only request/response semantics defined + CHANNELBIND = 0x009, // (only request/response semantics defined) + }; + + // Authentication result. + enum class Authentication { + OK = 0, + UNAUTHORIZED = 1, + BAD_REQUEST = 2 + }; + + struct EnumClassHash { + template + std::size_t operator()(T t) const { + return static_cast(t); + } + }; + struct ClassMethodHash { + bool operator()(std::pair key) const { + std::size_t h = 0; + h ^= std::hash()((uint8_t)key.first) << 1; + h ^= std::hash()((uint8_t)key.second) << 2; + return h; + } + }; + + static const size_t HEADER_SIZE = 20; + static const uint8_t _magicCookie[]; + + static bool isStun(const uint8_t *data, size_t len); + static Class getClass(const uint8_t *data, size_t len); + static Method getMethod(const uint8_t *data, size_t len); + static StunPacket::Ptr parse(const uint8_t *data, size_t len); + static std::string mappingClassEnum2Str(Class klass); + static std::string mappingMethodEnum2Str(Method method); + + StunPacket(Class klass, Method method, const char* transId = nullptr); + virtual ~StunPacket(); + + Class getClass() const { return _klass; } + + Method getMethod() const { return _method; } + + std::string getClassStr() const { return StrPrinter << mappingClassEnum2Str(_klass) << "(" << (uint32_t)_klass << ")"; } + + std::string getMethodStr() const { return StrPrinter << mappingMethodEnum2Str(_method) << "(" << (uint32_t)_method << ")"; } + + std::string dumpString(bool transId = false) const; + + const std::string& getTransactionId() const { return _transaction_id; } + + void setUfrag(std::string ufrag) { _ufrag = std::move(ufrag); } + const std::string& getUfrag() const { return _ufrag; } + + void setPassword(std::string password) { _password = std::move(password); } + const std::string& getPassword() const { return _password; } + + void setPeerUfrag(std::string peer_ufrag) { _peer_ufrag = std::move(peer_ufrag); } + const std::string& getPeerUfrag() const { return _peer_ufrag; } + + void setPeerPassword(std::string peer_password) { _peer_password = std::move(peer_password); } + const std::string& getPeerPassword() const { return _peer_password; } + + void setNeedMessageIntegrity(bool flag) { _need_message_integrity = flag; } + bool getNeedMessageIntegrity() const { return _need_message_integrity; } + + void setNeedFingerprint(bool flag) { _need_fingerprint = flag; } + bool getNeedFingerprint() const { return _need_fingerprint; } + + void refreshTransactionId() { _transaction_id = toolkit::makeRandStr(12, false); } + + void addAttribute(StunAttribute::Ptr attr); + void removeAttribute(StunAttribute::Type type); + bool hasAttribute(StunAttribute::Type type) const; + StunAttribute::Ptr getAttribute(StunAttribute::Type type) const; + + template + std::shared_ptr getAttribute() const { + auto attr = getAttribute(T::TYPE); + if (attr) { + return std::dynamic_pointer_cast(attr); + } + return nullptr; + } + + std::string getUsername() const; + uint64_t getPriority() const; + StunAttrErrorCode::Code getErrorCode() const; + + Authentication checkAuthentication(const std::string &ufrag, const std::string &password) const; + void serialize(); + + StunPacket::Ptr createSuccessResponse() const; + StunPacket::Ptr createErrorResponse(StunAttrErrorCode::Code errorCode) const; + + ///////Buffer override/////// + char *data() const override; + size_t size() const override; + +private: + bool loadFromData(const uint8_t *buf, size_t len); + + // attribute + bool loadAttrMessage(const uint8_t *buf, size_t len); + bool storeAttrMessage(); + size_t getAttrSize() const; + +protected: + + Class _klass; + Method _method; + std::string _transaction_id; // 12 bytes/96bits. + std::map _attribute_map; + toolkit::BufferRaw::Ptr _data; + std::string _ufrag; + std::string _password; + std::string _peer_ufrag; + std::string _peer_password; + size_t _message_integrity_data_len = 0; //MESSAGE_INTEGRITY属性之前的字段 + + bool _need_message_integrity = true; + bool _need_fingerprint = true; +}; + +class BindingPacket : public StunPacket { +public: + BindingPacket() : StunPacket(Class::REQUEST, Method::BINDING) {}; + virtual ~BindingPacket() {}; +}; + +class SuccessResponsePacket : public StunPacket { +public: + SuccessResponsePacket(Method method, const std::string& transaction_id); + virtual ~SuccessResponsePacket() {}; +}; + +class ErrorResponsePacket : public StunPacket { +public: + ErrorResponsePacket(Method method, const std::string& transaction_id, StunAttrErrorCode::Code error_code); + virtual ~ErrorResponsePacket() {}; +}; + +//////////// TURN ////////////////////////// + +class TurnPacket : public StunPacket { +public: + TurnPacket(Class klass, Method method) : StunPacket(klass, method) {} + virtual ~TurnPacket() {}; +}; + +class AllocatePacket : public TurnPacket { +public: + AllocatePacket() : TurnPacket(Class::REQUEST, Method::ALLOCATE) {}; + virtual ~AllocatePacket() {}; +}; + +class RefreshPacket : public TurnPacket { +public: + RefreshPacket() : TurnPacket(Class::REQUEST, Method::REFRESH) {}; + virtual ~RefreshPacket() {}; +}; + +class CreatePermissionPacket : public TurnPacket { +public: + CreatePermissionPacket() : TurnPacket(Class::REQUEST, Method::CREATEPERMISSION) {}; + virtual ~CreatePermissionPacket() {}; +}; + +class ChannelBindPacket : public TurnPacket { +public: + ChannelBindPacket() : TurnPacket(Class::REQUEST, Method::CHANNELBIND) {}; + virtual ~ChannelBindPacket() {}; +}; + +class SendIndicationPacket : public TurnPacket { +public: + SendIndicationPacket() : TurnPacket(Class::INDICATION, Method::SEND) {}; + virtual ~SendIndicationPacket() {}; +}; + +class DataIndicationPacket : public TurnPacket { +public: + DataIndicationPacket() : TurnPacket(Class::INDICATION, Method::DATA) {}; + virtual ~DataIndicationPacket() {}; +}; + +class DataPacket : public TurnPacket { +public: + DataPacket() : TurnPacket(Class::INDICATION, Method::DATA) {}; + virtual ~DataPacket() {}; +}; + } // namespace RTC #endif diff --git a/webrtc/USAGE.md b/webrtc/USAGE.md new file mode 100644 index 00000000..67542f0a --- /dev/null +++ b/webrtc/USAGE.md @@ -0,0 +1,256 @@ +# WebRTC 使用说明 + +## WebRTC 架构 + +### 1. SFU 模式架构 (WHIP/WHEP) + +SFU 模式通过服务器中继媒体流,支持多路复用和转码: + +``` + WebRTC SFU 模式 (WHIP/WHEP) + + 推流端 (WHIP) 拉流端 (WHEP) + +----------------+ +-----------------+ + | Encoder | | Player | + | (Browser/ZLM) | | (Browser/ZLM) | + +----------------+ +-----------------+ + | | + | WHIP Protocol | WHEP Protocol + | (WebRTC ingest) | (WebRTC playback) + | | + v v + +-------------------------------------------------------------------+ + | ZLMediaKit Server | + +-------------------------------------------------------------------+ + - WHIP: WebRTC-HTTP Ingestion Protocol (推流) + - WHEP: WebRTC-HTTP Egress Protocol (拉流) +``` + +### 2. P2P 模式架构 + +P2P 模式允许客户端之间直接建立连接,减少服务器负载: +基于Websocket的自定义信令协议 + +``` + WebRTt WC P2P 模式 + + 客户端 A 客户端 B + +------------+ +-------------+ + | Browser/ZLM| | Browser/ZLM | + +------------+ +-------------+ + | | + | 1. 信令交换 (SDP Offer/Answer) | + | 2. ICE Candidate 交换 | + +---------------- -----+-----------------------+ + | | | + | +-----------------------+ | + | | ZLMediaKit Server | | + | | 信令服务器 (WebSocket) | | + | | STUN 服务器 | | + | | TURN 服务器 | | + | +-----------------------+ | + | | + +-----------------------------------------------+ + 直接P2P连接 +``` + +## HTTP API 接口 + +### 1. WebRTC 房间管理 + +#### `/index/api/addWebrtcRoomKeeper` +添加WebRTC到指定信令服务器,用于在信令服务器中维持房间连接。 + +**请求参数:** +- `secret`: 接口访问密钥 +- `server_host`: 信令服务器主机地址 +- `server_port`: 信令服务器端口 +- `room_id`: 房间ID,信令服务器会对该ID进行唯一性检查 + +#### `/index/api/delWebrtcRoomKeeper` +删除指定的信令服务器。 + +**请求参数:** +- `secret`: 接口访问密钥 +- `room_key`: 房间保持器的唯一标识符 + +#### `/index/api/listWebrtcRoomKeepers` +列出所有信令服务器。 + +**请求参数:** +- `secret`: 接口访问密钥 + +### 2. WebRTC 房间会话管理 + +#### `/index/api/listWebrtcRooms` +列出所有活跃的WebRTC Peer会话信息。 + +**请求参数:** +- `secret`: 接口访问密钥 + +### 3. WebRTC 推流和拉流接口 + +ZLMediaKit 支持通过标准的流代理接口来创建WebRTC推流和拉流,支持两种信令模式: + +##### `/index/api/addStreamProxy` - WebRTC 拉流代理 + +通过此接口可以创建WebRTC拉流代理,支持两种信令协议模式。 + +**请求参数:** +- `secret`: 接口访问密钥 +- `vhost`: 虚拟主机名,默认为 `__defaultVhost__` +- `app`: 应用名 +- `stream`: 流ID +- `url`: WebRTC源URL,支持两种格式 + +**WebRTC URL 格式:** + +1. **WHIP/WHEP 模式 (SFU)** - 标准HTTP信令协议: + ``` + # HTTP + webrtc://server_host:server_port/app/stream_id?signaling_protocols=0 + + # HTTPS (暂未实现) + webrtcs://server_host:server_port/app/stream_id?signaling_protocols=0 + ``` + +2. **WebSocket P2P 模式** - 自定义信令协议: + ``` + # WebSocket + webrtc://signaling_server_host:signaling_server_port/app/stream_id?signaling_protocols=1&peer_room_id=target_room_id + + # WebSocket Secure (暂未实现) + webrtcs://signaling_server_host:signaling_server_port/app/stream_id?signaling_protocols=1&peer_room_id=target_room_id + ``` + +**请求示例:** +```bash +# WHIP/WHEP 模式拉流 +curl -X POST "http://127.0.0.1/index/api/addStreamProxy" \ + -d "secret=your_secret" \ + -d "vhost=__defaultVhost__" \ + -d "app=live" \ + -d "stream=test" \ + -d "url=webrtc://source.server.com:80/live/source_stream?signaling_protocols=0" + +# P2P 模式拉流 +curl -X POST "http://127.0.0.1/index/api/addStreamProxy" \ + -d "secret=your_secret" \ + -d "vhost=__defaultVhost__" \ + -d "app=live" \ + -d "stream=test" \ + -d "url=webrtc://signaling.server.com:3000/live/source_stream??signaling_protocols=1%26peer_room_id=target_room_id" +``` + +#### `/index/api/addStreamPusherProxy` - WebRTC 推流代理 (暂未实现) + +通过此接口可以创建WebRTC推流代理,将现有流推送到WebRTC目标服务器。 + +**请求参数:** +- `secret`: 接口访问密钥 +- `schema`: 源流协议 (如: rtmp, rtsp, hls等) +- `vhost`: 虚拟主机名 +- `app`: 应用名 +- `stream`: 源流ID +- `dst_url`: WebRTC目标推流URL + +**WebRTC 推流 URL 格式:** + +1. **WHIP 模式 (SFU)** - 推流到支持WHIP的服务器: + ``` + # HTTP + webrtc://target_server:port/app/stream_id?signaling_protocols=0 + + # HTTPS (暂未实现) + webrtcs://target_server:port/app/stream_id?signaling_protocols=0 + ``` + +2. **WebSocket P2P 模式** - 推流到P2P房间 + ``` + # WebSocket + webrtc://signaling_server:port/app/stream_id?signaling_protocols=1&peer_room_id=target_room + # WebSocket Secure + webrtcs://signaling_server:port/app/stream_id?signaling_protocols=1&peer_room_id=target_room + ``` + +**请求示例:** +```bash +# 将RTSP流推送到WHIP服务器 +curl -X POST "http://127.0.0.1/index/api/addStreamPusherProxy" \ + -d "secret=your_secret" \ + -d "schema=rtsp" \ + -d "vhost=__defaultVhost__" \ + -d "app=live" \ + -d "stream=test" \ + -d "dst_url=webrtc://target.server.com:80/live/target_stream?signaling_protocols=0" + +# 将RTSP流推送到P2P房间 +curl -X POST "http://127.0.0.1/index/api/addStreamPusherProxy" \ + -d "secret=your_secret" \ + -d "schema=rtsp" \ + -d "vhost=__defaultVhost__" \ + -d "app=live" \ + -d "stream=test" \ + -d "dst_url=webrtc://signaling.server.com:3000/live/room_stream?signaling_protocols=1%26peer_room_id=target_room_id" +``` + +#### URL 参数说明 + +- `signaling_protocols`: 信令协议类型 + - `0`: WHIP/WHEP 模式(默认) + - **协议**: 基于HTTP的标准WebRTC信令协议 + - **应用场景**: SFU(选择性转发单元)模式,适合广播和多人会议 + - `1`: WebSocket P2P 模式 + - **协议**: 基于WebSocket的自定义信令协议 + - **应用场景**: 点对点直连,适合低延迟通话和私人通信 +- `peer_room_id`: P2P模式下的目标房间ID(仅P2P模式需要) + +### 4. WebRTC 代理播放器信息查询 + +#### `/index/api/getWebrtcProxyPlayerInfo` +获取WebRTC代理播放器的连接信息和状态。 + +**请求参数:** +- `secret`: 接口访问密钥 +- `key`: 代理播放器标识符 + + +## WebRTC 相关配置项 + +在 `config.ini` 中的 `[rtc]` 配置段: + +``` ini +[rtc] +#webrtc 信令服务器端口 +signalingPort=3000 +#STUN/TURN服务器端口 +icePort=3478 +#STUN/TURN端口是否使能TURN服务 +enableTurn=1 + +#TURN服务分配端口池 +portRange=50000-65000 + +#ICE传输策略:0=不限制(默认),1=仅支持Relay转发,2=仅支持P2P直连 +iceTransportPolicy=0 + +#STUN/TURN 服务Ice密码 +iceUfrag=ZLMediaKit +icePwd=ZLMediaKit +``` + +## Examples +- [zlm_peerconnection](https://gitee.com/libwebrtc_develop/libwebrtc/tree/feature-zlm/examples/zlm_peerconnection) + 一个基于libwebrtc 实现的zlm p2p 代理拉流简单示例 + +## 注意事项 + +1. **防火墙配置**: 确保 WebRTC 相关端口已开放 + - 信令端口: 3000 (默认) + - STUN/TURN 端口: 3478 (默认) + - TURN Alloc 端口范围: 50000-65000(默认) + +## 暂未实现的功能: +- Webrtc信令服务的安全校验 +- 自定义外部STUN/TURN 服务器的配置 +- webrtc代理推流 diff --git a/webrtc/Utils.hpp b/webrtc/Utils.hpp deleted file mode 100644 index d1386504..00000000 --- a/webrtc/Utils.hpp +++ /dev/null @@ -1,118 +0,0 @@ -/** -ISC License - -Copyright © 2015, Iñaki Baz Castillo - -Permission to use, copy, modify, and/or distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - */ - -#ifndef MS_UTILS_HPP -#define MS_UTILS_HPP - -#if defined(_WIN32) -#include -#include -#pragma comment (lib, "Ws2_32.lib") -#else -#include -#endif // defined(_WIN32) - -#include // PRIu64, etc -#include // size_t -#include // uint8_t, etc - -namespace Utils { - -class Byte { -public: - /** - * Getters below get value in Host Byte Order. - * Setters below set value in Network Byte Order. - */ - static uint8_t Get1Byte(const uint8_t *data, size_t i); - static uint16_t Get2Bytes(const uint8_t *data, size_t i); - static uint32_t Get3Bytes(const uint8_t *data, size_t i); - static uint32_t Get4Bytes(const uint8_t *data, size_t i); - static uint64_t Get8Bytes(const uint8_t *data, size_t i); - static void Set1Byte(uint8_t *data, size_t i, uint8_t value); - static void Set2Bytes(uint8_t *data, size_t i, uint16_t value); - static void Set3Bytes(uint8_t *data, size_t i, uint32_t value); - static void Set4Bytes(uint8_t *data, size_t i, uint32_t value); - static void Set8Bytes(uint8_t *data, size_t i, uint64_t value); - static uint16_t PadTo4Bytes(uint16_t size); - static uint32_t PadTo4Bytes(uint32_t size); -}; - -/* Inline static methods. */ - -inline uint8_t Byte::Get1Byte(const uint8_t *data, size_t i) { return data[i]; } - -inline uint16_t Byte::Get2Bytes(const uint8_t *data, size_t i) { - return uint16_t{data[i + 1]} | uint16_t{data[i]} << 8; -} - -inline uint32_t Byte::Get3Bytes(const uint8_t *data, size_t i) { - return uint32_t{data[i + 2]} | uint32_t{data[i + 1]} << 8 | uint32_t{data[i]} << 16; -} - -inline uint32_t Byte::Get4Bytes(const uint8_t *data, size_t i) { - return uint32_t{data[i + 3]} | uint32_t{data[i + 2]} << 8 | uint32_t{data[i + 1]} << 16 | - uint32_t{data[i]} << 24; -} - -inline uint64_t Byte::Get8Bytes(const uint8_t *data, size_t i) { - return uint64_t{Byte::Get4Bytes(data, i)} << 32 | Byte::Get4Bytes(data, i + 4); -} - -inline void Byte::Set1Byte(uint8_t *data, size_t i, uint8_t value) { data[i] = value; } - -inline void Byte::Set2Bytes(uint8_t *data, size_t i, uint16_t value) { - data[i + 1] = static_cast(value); - data[i] = static_cast(value >> 8); -} - -inline void Byte::Set3Bytes(uint8_t *data, size_t i, uint32_t value) { - data[i + 2] = static_cast(value); - data[i + 1] = static_cast(value >> 8); - data[i] = static_cast(value >> 16); -} - -inline void Byte::Set4Bytes(uint8_t *data, size_t i, uint32_t value) { - data[i + 3] = static_cast(value); - data[i + 2] = static_cast(value >> 8); - data[i + 1] = static_cast(value >> 16); - data[i] = static_cast(value >> 24); -} - -inline void Byte::Set8Bytes(uint8_t *data, size_t i, uint64_t value) { - data[i + 7] = static_cast(value); - data[i + 6] = static_cast(value >> 8); - data[i + 5] = static_cast(value >> 16); - data[i + 4] = static_cast(value >> 24); - data[i + 3] = static_cast(value >> 32); - data[i + 2] = static_cast(value >> 40); - data[i + 1] = static_cast(value >> 48); - data[i] = static_cast(value >> 56); -} - -inline uint16_t Byte::PadTo4Bytes(uint16_t size) { - // If size is not multiple of 32 bits then pad it. - if (size & 0x03) - return (size & 0xFFFC) + 4; - else - return size; -} - -}// namespace Utils - -#endif diff --git a/webrtc/WebRtcClient.cpp b/webrtc/WebRtcClient.cpp new file mode 100755 index 00000000..54608817 --- /dev/null +++ b/webrtc/WebRtcClient.cpp @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "Network/TcpClient.h" +#include "Common/config.h" +#include "Common/Parser.h" +#include "WebRtcClient.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +// # WebRTCUrl format +// ## whep/whip over http sfu: webrtc://server_host:server_port/{{app}}/{{streamid}} +// ## whep/whip over https sfu: webrtcs://server_host:server_port/{{app}}/{{streamid}} +// ## websocket p2p: webrtc://{{signaling_server_host}}:{{signaling_server_port}}/{{app}}/{{streamid}}?room_id={{peer_room_id}} +// ## websockets p2p: webrtcs://{{signaling_server_host}}:{{signaling_server_port}}/{{app}}/{{streamid}}?room_id={{peer_room_id}} +void WebRTCUrl::parse(const string &strUrl, bool isPlayer) { + DebugL << "url: " << strUrl; + _full_url = strUrl; + auto url = strUrl; + auto pos = url.find("?"); + if (pos != string::npos) { + _params = url.substr(pos + 1); + url.erase(pos); + } + + auto schema_pos = url.find("://"); + if (schema_pos != string::npos) { + auto schema = url.substr(0, schema_pos); + _is_ssl = strcasecmp(schema.data(), "webrtcs") == 0; + } else { + schema_pos = -3; + } + // set default port + _port = _is_ssl ? 443 : 80; + auto split_vec = split(url.substr(schema_pos + 3), "/"); + if (split_vec.size() > 0) { + splitUrl(split_vec[0], _host, _port); + _vhost = _host; + if (_vhost == "localhost" || isIP(_vhost.data())) { + // 如果访问的是localhost或ip,那么则为默认虚拟主机 + _vhost = DEFAULT_VHOST; + } + } + if (split_vec.size() > 1) { + _app = split_vec[1]; + } + if (split_vec.size() > 2) { + string stream_id; + for (size_t i = 2; i < split_vec.size(); ++i) { + stream_id.append(split_vec[i] + "/"); + } + if (stream_id.back() == '/') { + stream_id.pop_back(); + } + _stream = stream_id; + } + + // for vhost + auto kv = Parser::parseArgs(_params); + auto it = kv.find(VHOST_KEY); + if (it != kv.end()) { + _vhost = it->second; + } + + GET_CONFIG(bool, enableVhost, General::kEnableVhost); + if (!enableVhost || _vhost.empty()) { + // 如果关闭虚拟主机或者虚拟主机为空,则设置虚拟主机为默认 + _vhost = DEFAULT_VHOST; + } + + // for peer_room_id + it = kv.find("peer_room_id"); + if (it != kv.end()) { + _peer_room_id = it->second; + } + + it = kv.find("signaling_protocols"); + if (it != kv.end()) { + _signaling_protocols = (WebRtcTransport::SignalingProtocols)(stoi(it->second)); + } + + auto suffix = _host + ":" + to_string(_port); + suffix += (isPlayer ? "/index/api/whep" : "/index/api/whip"); + suffix += "?app=" + _app + "&stream=" + _stream; + if (!_params.empty()) { + suffix += "&" + _params; + } + if (_is_ssl) { + _negotiate_url = StrPrinter << "https://" << suffix << endl; + } else { + _negotiate_url = StrPrinter << "http://" << suffix << endl; + } +} + +//////////// WebRtcClient ////////////////////////// + +WebRtcClient::WebRtcClient(toolkit::EventPoller::Ptr poller) { + DebugL; + _poller = poller ? std::move(poller) : EventPollerPool::Instance().getPoller(); +} + +WebRtcClient::~WebRtcClient() { + doBye(); + DebugL; +} + +void WebRtcClient::startConnect() { + DebugL; + doNegotiate(); +} + +void WebRtcClient::connectivityCheck() { + DebugL; + return _transport->connectivityCheckForSFU(); +} + +void WebRtcClient::onNegotiateFinish() { + DebugL; + _is_negotiate_finished = true; + if (WebRtcTransport::SignalingProtocols::WEBSOCKET == _url._signaling_protocols) { + // P2P模式需要gathering candidates + gatheringCandidate(_peer->getIceServer()); + } else if (WebRtcTransport::SignalingProtocols::WHEP_WHIP == _url._signaling_protocols) { + // SFU模式不会存在IP不通的情况, answer中就携带了candidates, 直接进行connectivityCheck + connectivityCheck(); + } +} + +void WebRtcClient::doNegotiate() { + DebugL; + switch (_url._signaling_protocols) { + case WebRtcTransport::SignalingProtocols::WHEP_WHIP: return doNegotiateWhepOrWhip(); + case WebRtcTransport::SignalingProtocols::WEBSOCKET: return doNegotiateWebsocket(); + default: throw std::invalid_argument(StrPrinter << "not support signaling_protocols: " << (int)_url._signaling_protocols); + } +} + +void WebRtcClient::doNegotiateWhepOrWhip() { + DebugL << _url._negotiate_url; + + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + auto offer_sdp = _transport->createOfferSdp(); + DebugL << "send offer:\n" << offer_sdp; + + _negotiate = make_shared(); + _negotiate->setMethod("POST"); + _negotiate->addHeader("Content-Type", "application/sdp"); + _negotiate->setBody(std::move(offer_sdp)); + _negotiate->startRequester(_url._negotiate_url, [weak_self](const toolkit::SockException &ex, const Parser &response) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + if (ex) { + WarnL << "network err:" << ex; + strong_self->onResult(ex); + return; + } + + DebugL << "status:" << response.status() << "\r\n" + << "Location:\r\n" + << response.getHeader()["Location"] << "\r\nrecv answer:\n" + << response.content(); + strong_self->_url._delete_url = response.getHeader()["Location"]; + if ("201" != response.status()) { + strong_self->onResult(SockException(Err_other, response.content())); + return; + } + strong_self->_transport->setAnswerSdp(response.content()); + strong_self->onNegotiateFinish(); + }, getTimeOutSec()); +} + +void WebRtcClient::doNegotiateWebsocket() { + DebugL; +#if 0 + //TODO: 当前暂将每一路呼叫都使用一个独立的peer_connection,不复用 + _peer = getWebrtcRoomKeeper(_url._host, _url._port); + if (_peer) { + checkIn(); + return; + } +#endif + + // 未注册的,先增加注册流程,并在此次播放结束后注销 + InfoL << (StrPrinter << "register to signaling server " << _url._host << "::" << _url._port << " first"); + auto room_id = "ringing_" + makeRandStr(16); + _peer = make_shared(_url._host, _url._port, _url._is_ssl, room_id); + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + _peer->setOnConnect([weak_self](const SockException &ex) { + if (auto strong_self = weak_self.lock()) { + if (ex) { + strong_self->onResult(ex); + return; + } + + auto cb = [weak_self](const SockException &ex, const string &key) { + if (auto strong_self = weak_self.lock()) { + strong_self->checkIn(); + } + }; + strong_self->_peer->regist(cb); + } + }); + _peer->connect(); +} + +void WebRtcClient::checkIn() { + DebugL; + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + auto tuple = MediaTuple(_url._vhost, _url._app, _url._stream, _url._params); + _peer->checkIn(_url._peer_room_id, tuple, _transport->getIdentifier(), _transport->createOfferSdp(), isPlayer(), + [weak_self](const SockException &ex, const std::string &answer) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + if (ex) { + WarnL << "network err:" << ex; + strong_self->onResult(ex); + return; + } + + strong_self->_transport->setAnswerSdp(answer); + strong_self->onNegotiateFinish(); + }, getTimeOutSec()); +} + +void WebRtcClient::checkOut() { + DebugL; + auto tuple = MediaTuple(_url._vhost, _url._app, _url._stream); + if (_peer) { + _peer->checkOut(_url._peer_room_id); + _peer->unregist([](const SockException &ex) {}); + } +} + +void WebRtcClient::candidate(const std::string &candidate, const std::string &ufrag, const std::string &pwd) { + _peer->candidate(_transport->getIdentifier(), candidate, ufrag, pwd); +} + +void WebRtcClient::gatheringCandidate(IceServerInfo::Ptr ice_server) { + DebugL; + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + _transport->gatheringCandidate(ice_server, [weak_self](const std::string& transport_identifier, const std::string& candidate, + const std::string& ufrag, const std::string& pwd) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->candidate(candidate, ufrag, pwd); + }); +} + +void WebRtcClient::doBye() { + DebugL; + if (!_is_negotiate_finished) { + return; + } + + switch (_url._signaling_protocols) { + case WebRtcTransport::SignalingProtocols::WHEP_WHIP: return doByeWhepOrWhip(); + case WebRtcTransport::SignalingProtocols::WEBSOCKET: return checkOut(); + default: throw std::invalid_argument(StrPrinter << "not support signaling_protocols: " << (int)_url._signaling_protocols); + } + _is_negotiate_finished = false; +} + +void WebRtcClient::doByeWhepOrWhip() { + DebugL; + if (!_negotiate) { + return; + } + _negotiate->setMethod("DELETE"); + _negotiate->setBody(""); + _negotiate->startRequester(_url._delete_url, [](const toolkit::SockException &ex, const Parser &response) { + if (ex) { + WarnL << "network err:" << ex; + return; + } + DebugL << "status:" << response.status(); + }, getTimeOutSec()); +} + +float WebRtcClient::getTimeOutSec() { + GET_CONFIG(uint32_t, timeout, Rtc::kTimeOutSec); + if (timeout <= 0) { + WarnL << "config rtc. " << Rtc::kTimeOutSec << ": " << timeout << " not vaild"; + return 5.0; + } + return (float)timeout; +} + +} /* namespace mediakit */ diff --git a/webrtc/WebRtcClient.h b/webrtc/WebRtcClient.h new file mode 100755 index 00000000..914b9117 --- /dev/null +++ b/webrtc/WebRtcClient.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTC_CLIENT_H +#define ZLMEDIAKIT_WEBRTC_CLIENT_H + +#include "Http/HttpRequester.h" +#include "Sdp.h" +#include "WebRtcTransport.h" +#include "WebRtcSignalingPeer.h" +#include +#include + +namespace mediakit { + +// 解析webrtc 信令url的工具类 +class WebRTCUrl { +public: + bool _is_ssl; + std::string _full_url; + std::string _negotiate_url; // for whep or whip + std::string _delete_url; // for whep or whip + std::string _target_secret; + std::string _params; + std::string _host; + uint16_t _port; + std::string _vhost; + std::string _app; + std::string _stream; + WebRtcTransport::SignalingProtocols _signaling_protocols = WebRtcTransport::SignalingProtocols::WHEP_WHIP; + std::string _peer_room_id; // peer room_id + +public: + void parse(const std::string &url, bool isPlayer); + +private: +}; + +// 实现了webrtc代理功能 +class WebRtcClient : public std::enable_shared_from_this { +public: + using Ptr = std::shared_ptr; + + WebRtcClient(toolkit::EventPoller::Ptr poller); + virtual ~WebRtcClient(); + + const toolkit::EventPoller::Ptr &getPoller() const { return _poller; } + void setPoller(toolkit::EventPoller::Ptr poller) { _poller = std::move(poller); } + + // 获取WebRTC transport,用于API查询 + const WebRtcTransport::Ptr &getWebRtcTransport() const { return _transport; } + +protected: + virtual bool isPlayer() = 0; + virtual void startConnect(); + virtual void onResult(const toolkit::SockException &ex) = 0; + virtual void onNegotiateFinish(); + virtual float getTimeOutSec(); + + void doNegotiate(); + void doNegotiateWebsocket(); + void doNegotiateWhepOrWhip(); + void checkIn(); + void doBye(); + void doByeWhepOrWhip(); + void checkOut(); + + void gatheringCandidate(IceServerInfo::Ptr ice_server); + void connectivityCheck(); + void candidate(const std::string &candidate, const std::string &ufrag, const std::string &pwd); + +protected: + toolkit::EventPoller::Ptr _poller; + + // for _negotiate_sdp + WebRTCUrl _url; + HttpRequester::Ptr _negotiate = nullptr; + WebRtcSignalingPeer::Ptr _peer = nullptr; + WebRtcTransport::Ptr _transport = nullptr; + bool _is_negotiate_finished = false; + +}; + +} /*namespace mediakit */ +#endif /* ZLMEDIAKIT_WEBRTC_CLIENT_H */ diff --git a/webrtc/WebRtcEchoTest.cpp b/webrtc/WebRtcEchoTest.cpp index fcbd1266..f74a0187 100644 --- a/webrtc/WebRtcEchoTest.cpp +++ b/webrtc/WebRtcEchoTest.cpp @@ -10,6 +10,8 @@ #include "WebRtcEchoTest.h" +using namespace toolkit; + namespace mediakit { WebRtcEchoTest::Ptr WebRtcEchoTest::create(const EventPoller::Ptr &poller) { diff --git a/webrtc/WebRtcEchoTest.h b/webrtc/WebRtcEchoTest.h index e6249ff2..fffb5292 100644 --- a/webrtc/WebRtcEchoTest.h +++ b/webrtc/WebRtcEchoTest.h @@ -18,7 +18,7 @@ namespace mediakit { class WebRtcEchoTest : public WebRtcTransportImp { public: using Ptr = std::shared_ptr; - static Ptr create(const EventPoller::Ptr &poller); + static Ptr create(const toolkit::EventPoller::Ptr &poller); protected: ///////WebRtcTransportImp override/////// @@ -31,7 +31,7 @@ protected: void onBeforeEncryptRtcp(const char *buf, int &len, void *ctx) override {}; private: - WebRtcEchoTest(const EventPoller::Ptr &poller); + WebRtcEchoTest(const toolkit::EventPoller::Ptr &poller); }; }// namespace mediakit diff --git a/webrtc/WebRtcPlayer.cpp b/webrtc/WebRtcPlayer.cpp index 158042d3..4f11f17d 100644 --- a/webrtc/WebRtcPlayer.cpp +++ b/webrtc/WebRtcPlayer.cpp @@ -1,168 +1,338 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#include "WebRtcPlayer.h" - -#include "Common/config.h" -#include "Extension/Factory.h" -#include "Util/base64.h" - -using namespace std; - -namespace mediakit { - -WebRtcPlayer::Ptr WebRtcPlayer::create(const EventPoller::Ptr &poller, - const RtspMediaSource::Ptr &src, - const MediaInfo &info) { - WebRtcPlayer::Ptr ret(new WebRtcPlayer(poller, src, info), [](WebRtcPlayer *ptr) { - ptr->onDestory(); - delete ptr; - }); - ret->onCreate(); - return ret; -} - -WebRtcPlayer::WebRtcPlayer(const EventPoller::Ptr &poller, - const RtspMediaSource::Ptr &src, - const MediaInfo &info) : WebRtcTransportImp(poller) { - _media_info = info; - _play_src = src; - CHECK(src); - - GET_CONFIG(bool, direct_proxy, Rtsp::kDirectProxy); - _send_config_frames_once = direct_proxy; -} - -void WebRtcPlayer::onStartWebRTC() { - auto playSrc = _play_src.lock(); - if(!playSrc){ - onShutdown(SockException(Err_shutdown, "rtsp media source was shutdown")); - return ; - } - WebRtcTransportImp::onStartWebRTC(); - if (canSendRtp()) { - playSrc->pause(false); - _reader = playSrc->getRing()->attach(getPoller(), true); - weak_ptr weak_self = static_pointer_cast(shared_from_this()); - weak_ptr weak_session = static_pointer_cast(getSession()); - _reader->setGetInfoCB([weak_session]() { - Any ret; - ret.set(static_pointer_cast(weak_session.lock())); - return ret; - }); - _reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt) { - auto strong_self = weak_self.lock(); - if (!strong_self) { - return; - } - - if (strong_self->_send_config_frames_once && !pkt->empty()) { - const auto &first_rtp = pkt->front(); - strong_self->sendConfigFrames(first_rtp->getSeq(), first_rtp->sample_rate, first_rtp->getStamp(), first_rtp->ntp_stamp); - strong_self->_send_config_frames_once = false; - } - - size_t i = 0; - pkt->for_each([&](const RtpPacket::Ptr &rtp) { - //TraceL<<"send track type:"<type<<" ts:"<getStamp()<<" ntp:"<ntp_stamp<<" size:"<getPayloadSize()<<" i:"<onSendRtp(rtp, ++i == pkt->size()); - }); - }); - _reader->setDetachCB([weak_self]() { - auto strong_self = weak_self.lock(); - if (!strong_self) { - return; - } - strong_self->onShutdown(SockException(Err_shutdown, "rtsp ring buffer detached")); - }); - - _reader->setMessageCB([weak_self] (const toolkit::Any &data) { - auto strong_self = weak_self.lock(); - if (!strong_self) { - return; - } - if (data.is()) { - auto &buffer = data.get(); - // PPID 51: 文本string [AUTO-TRANSLATED:69a8cf81] - // PPID 51: Text string - // PPID 53: 二进制 [AUTO-TRANSLATED:faf00c3e] - // PPID 53: Binary - strong_self->sendDatachannel(0, 51, buffer.data(), buffer.size()); - } else { - WarnL << "Send unknown message type to webrtc player: " << data.type_name(); - } - }); - } -} -void WebRtcPlayer::onDestory() { - auto duration = getDuration(); - auto bytes_usage = getBytesUsage(); - // 流量统计事件广播 [AUTO-TRANSLATED:6b0b1234] - // Traffic statistics event broadcast - GET_CONFIG(uint32_t, iFlowThreshold, General::kFlowThreshold); - if (_reader && getSession()) { - WarnL << "RTC播放器(" << _media_info.shortUrl() << ")结束播放,耗时(s):" << duration; - if (bytes_usage >= iFlowThreshold * 1024) { - NOTICE_EMIT(BroadcastFlowReportArgs, Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, true, *getSession()); - } - } - WebRtcTransportImp::onDestory(); -} - -void WebRtcPlayer::onRtcConfigure(RtcConfigure &configure) const { - auto playSrc = _play_src.lock(); - if(!playSrc){ - return ; - } - WebRtcTransportImp::onRtcConfigure(configure); - // 这是播放 [AUTO-TRANSLATED:d93c019e] - // This is playing - configure.audio.direction = configure.video.direction = RtpDirection::sendonly; - configure.setPlayRtspInfo(playSrc->getSdp()); -} - -void WebRtcPlayer::sendConfigFrames(uint32_t before_seq, uint32_t sample_rate, uint32_t timestamp, uint64_t ntp_timestamp) { - auto play_src = _play_src.lock(); - if (!play_src) { - return; - } - SdpParser parser(play_src->getSdp()); - auto video_sdp = parser.getTrack(TrackVideo); - if (!video_sdp) { - return; - } - auto video_track = dynamic_pointer_cast(Factory::getTrackBySdp(video_sdp)); - if (!video_track) { - return; - } - auto frames = video_track->getConfigFrames(); - if (frames.empty()) { - return; - } - auto encoder = mediakit::Factory::getRtpEncoderByCodecId(video_track->getCodecId(), 0); - if (!encoder) { - return; - } - - GET_CONFIG(uint32_t, video_mtu, Rtp::kVideoMtuSize); - encoder->setRtpInfo(0, video_mtu, sample_rate, 0, 0, 0); - - auto seq = before_seq - frames.size(); - for (const auto &frame : frames) { - auto rtp = encoder->getRtpInfo().makeRtp(TrackVideo, frame->data() + frame->prefixSize(), frame->size() - frame->prefixSize(), false, 0); - auto header = rtp->getHeader(); - header->seq = htons(seq++); - header->stamp = htonl(timestamp); - rtp->ntp_stamp = ntp_timestamp; - onSendRtp(rtp, false); - } -} - +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcPlayer.h" + +#include "Common/config.h" +#include "Extension/Factory.h" +#include "Util/base64.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +namespace Rtc { +#define RTC_FIELD "rtc." +const string kBfilter = RTC_FIELD "bfilter"; +static onceToken token([]() { mINI::Instance()[kBfilter] = 0; }); +} // namespace Rtc + +H264BFrameFilter::H264BFrameFilter() + : _last_seq(0) + , _last_stamp(0) + , _first_packet(true) {} + +RtpPacket::Ptr H264BFrameFilter::processPacket(const RtpPacket::Ptr &packet) { + if (!packet) { + return nullptr; + } + + if (isH264BFrame(packet)) { + return nullptr; + } + + auto cur_stamp = packet->getStamp(); + auto cur_seq = packet->getSeq(); + + if (_first_packet) { + _first_packet = false; + _last_seq = cur_seq; + _last_stamp = cur_stamp; + } + + // 处理时间戳连续性问题 + if (cur_stamp < _last_stamp) { + return nullptr; + } + _last_stamp = cur_stamp; + + // 处理 seq 连续性问题 + if (cur_seq > _last_seq + 4) { + RtpHeader *header = packet->getHeader(); + _last_seq = (_last_seq + 1) & 0xFFFF; + header->seq = htons(_last_seq); + } + + return packet; +} + +bool H264BFrameFilter::isH264BFrame(const RtpPacket::Ptr &packet) const { + uint8_t *payload = packet->getPayload(); + size_t payload_size = packet->getPayloadSize(); + + if (payload_size < 1) { + return false; + } + + uint8_t nal_unit_type = payload[0] & 0x1F; + switch (nal_unit_type) { + case 24: // STAP-A + return handleStapA(payload, payload_size); + case 28: // FU-A + return handleFua(payload, payload_size); + default: + if (nal_unit_type < 24) { + return isBFrameByNalType(nal_unit_type, payload + 1, payload_size - 1); + } + return false; + } +} + +bool H264BFrameFilter::handleStapA(const uint8_t *payload, size_t payload_size) const { + size_t offset = 1; + while (offset + 2 <= payload_size) { + uint16_t nalu_size = (payload[offset] << 8) | payload[offset + 1]; + offset += 2; + if (offset + nalu_size > payload_size || nalu_size < 1) { + return false; + } + uint8_t original_nal_type = payload[offset] & 0x1F; + if (original_nal_type < 24) { + if (isBFrameByNalType(original_nal_type, payload + offset + 1, nalu_size - 1)) { + return true; + } + } + offset += nalu_size; + } + return false; +} + +bool H264BFrameFilter::handleFua(const uint8_t *payload, size_t payload_size) const { + if (payload_size < 2) { + return false; + } + uint8_t fu_header = payload[1]; + uint8_t original_nal_type = fu_header & 0x1F; + bool start_bit = fu_header & 0x80; + if (start_bit) { + return isBFrameByNalType(original_nal_type, payload + 2, payload_size - 2); + } + return false; +} + +bool H264BFrameFilter::isBFrameByNalType(uint8_t nal_type, const uint8_t *data, size_t size) const { + if (size < 1) { + return false; + } + + if (nal_type != NAL_NIDR && nal_type != NAL_PARTITION_A && nal_type != NAL_PARTITION_B && nal_type != NAL_PARTITION_C) { + return false; + } + + uint8_t slice_type = extractSliceType(data, size); + return slice_type == H264SliceTypeB || slice_type == H264SliceTypeB1; +} + +int H264BFrameFilter::decodeExpGolomb(const uint8_t *data, size_t size, size_t &bitPos) const { + if (bitPos >= size * 8) + return -1; + + int leadingZeroBits = 0; + while (bitPos < size * 8 && !getBit(data, bitPos++)) { + leadingZeroBits++; + } + + int result = (1 << leadingZeroBits) - 1; + for (int i = 0; i < leadingZeroBits; i++) { + if (bitPos < size * 8) { + result += getBit(data, bitPos++) << (leadingZeroBits - i - 1); + } + } + + return result; +} + +int H264BFrameFilter::getBit(const uint8_t *data, size_t pos) const { + size_t byteIndex = pos / 8; + size_t bitOffset = pos % 8; + uint8_t byte = data[byteIndex]; + return (byte >> (7 - bitOffset)) & 0x01; +} + +uint8_t H264BFrameFilter::extractSliceType(const uint8_t *data, size_t size) const { + size_t bitPos = 0; + int first_mb_in_slice = decodeExpGolomb(data, size, bitPos); + int slice_type = decodeExpGolomb(data, size, bitPos); + + if (slice_type >= 0 && slice_type <= 9) { + return static_cast(slice_type); + } + return -1; +} + +WebRtcPlayer::Ptr WebRtcPlayer::create(const EventPoller::Ptr &poller, + const RtspMediaSource::Ptr &src, + const MediaInfo &info, + WebRtcTransport::Role role, + WebRtcTransport::SignalingProtocols signaling_protocols) { + WebRtcPlayer::Ptr ret(new WebRtcPlayer(poller, src, info), [](WebRtcPlayer *ptr) { + ptr->onDestory(); + delete ptr; + }); + ret->setRole(role); + ret->setSignalingProtocols(signaling_protocols); + ret->onCreate(); + return ret; +} + +WebRtcPlayer::WebRtcPlayer(const EventPoller::Ptr &poller, + const RtspMediaSource::Ptr &src, + const MediaInfo &info) : WebRtcTransportImp(poller) { + _media_info = info; + _play_src = src; + CHECK(src); + + GET_CONFIG(bool, direct_proxy, Rtsp::kDirectProxy); + _send_config_frames_once = direct_proxy; + + GET_CONFIG(bool, enable, Rtc::kBfilter); + _bfliter_flag = enable; + _is_h264 = false; + _bfilter = std::make_shared(); +} + +void WebRtcPlayer::onStartWebRTC() { + auto playSrc = _play_src.lock(); + if (!playSrc) { + onShutdown(SockException(Err_shutdown, "rtsp media source was shutdown")); + return; + } + WebRtcTransportImp::onStartWebRTC(); + if (canSendRtp()) { + playSrc->pause(false); + _reader = playSrc->getRing()->attach(getPoller(), true); + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + weak_ptr weak_session = static_pointer_cast(getSession()); + _reader->setGetInfoCB([weak_session]() { + Any ret; + ret.set(static_pointer_cast(weak_session.lock())); + return ret; + }); + _reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + + if (strong_self->_send_config_frames_once && !pkt->empty()) { + const auto &first_rtp = pkt->front(); + strong_self->sendConfigFrames(first_rtp->getSeq(), first_rtp->sample_rate, first_rtp->getStamp(), first_rtp->ntp_stamp); + strong_self->_send_config_frames_once = false; + } + + size_t i = 0; + pkt->for_each([&](const RtpPacket::Ptr &rtp) { + if (strong_self->_bfliter_flag) { + if (TrackVideo == rtp->type && strong_self->_is_h264) { + auto rtp_filter = strong_self->_bfilter->processPacket(rtp); + if (rtp_filter) { + strong_self->onSendRtp(rtp_filter, ++i == pkt->size()); + } + } else { + strong_self->onSendRtp(rtp, ++i == pkt->size()); + } + } else { + strong_self->onSendRtp(rtp, ++i == pkt->size()); + } + }); + }); + _reader->setDetachCB([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->onShutdown(SockException(Err_shutdown, "rtsp ring buffer detached")); + }); + + _reader->setMessageCB([weak_self](const toolkit::Any &data) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + if (data.is()) { + auto &buffer = data.get(); + // PPID 51: 文本string [AUTO-TRANSLATED:69a8cf81] + // PPID 51: Text string + // PPID 53: 二进制 [AUTO-TRANSLATED:faf00c3e] + // PPID 53: Binary + strong_self->sendDatachannel(0, 51, buffer.data(), buffer.size()); + } else { + WarnL << "Send unknown message type to webrtc player: " << data.type_name(); + } + }); + } +} +void WebRtcPlayer::onDestory() { + auto duration = getDuration(); + auto bytes_usage = getBytesUsage(); + // 流量统计事件广播 [AUTO-TRANSLATED:6b0b1234] + // Traffic statistics event broadcast + GET_CONFIG(uint32_t, iFlowThreshold, General::kFlowThreshold); + if (_reader && getSession()) { + WarnL << "RTC播放器(" << _media_info.shortUrl() << ")结束播放,耗时(s):" << duration; + if (bytes_usage >= iFlowThreshold * 1024) { + NOTICE_EMIT(BroadcastFlowReportArgs, Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, true, *getSession()); + } + } + WebRtcTransportImp::onDestory(); +} + +void WebRtcPlayer::onRtcConfigure(RtcConfigure &configure) const { + auto playSrc = _play_src.lock(); + if (!playSrc) { + return; + } + WebRtcTransportImp::onRtcConfigure(configure); + // 这是播放 [AUTO-TRANSLATED:d93c019e] + // This is playing + configure.audio.direction = configure.video.direction = RtpDirection::sendonly; + configure.setPlayRtspInfo(playSrc->getSdp()); +} + +void WebRtcPlayer::sendConfigFrames(uint32_t before_seq, uint32_t sample_rate, uint32_t timestamp, uint64_t ntp_timestamp) { + auto play_src = _play_src.lock(); + if (!play_src) { + return; + } + SdpParser parser(play_src->getSdp()); + auto video_sdp = parser.getTrack(TrackVideo); + if (!video_sdp) { + return; + } + auto video_track = dynamic_pointer_cast(Factory::getTrackBySdp(video_sdp)); + if (!video_track) { + return; + } + _is_h264 = video_track->getCodecId() == CodecH264; + auto frames = video_track->getConfigFrames(); + if (frames.empty()) { + return; + } + auto encoder = mediakit::Factory::getRtpEncoderByCodecId(video_track->getCodecId(), 0); + if (!encoder) { + return; + } + + GET_CONFIG(uint32_t, video_mtu, Rtp::kVideoMtuSize); + encoder->setRtpInfo(0, video_mtu, sample_rate, 0, 0, 0); + + auto seq = before_seq - frames.size(); + for (const auto &frame : frames) { + auto rtp = encoder->getRtpInfo().makeRtp(TrackVideo, frame->data() + frame->prefixSize(), frame->size() - frame->prefixSize(), false, 0); + auto header = rtp->getHeader(); + header->seq = htons(seq++); + header->stamp = htonl(timestamp); + rtp->ntp_stamp = ntp_timestamp; + onSendRtp(rtp, false); + } +} + }// namespace mediakit \ No newline at end of file diff --git a/webrtc/WebRtcPlayer.h b/webrtc/WebRtcPlayer.h index a964b380..4e47b8ff 100644 --- a/webrtc/WebRtcPlayer.h +++ b/webrtc/WebRtcPlayer.h @@ -1,54 +1,165 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_WEBRTCPLAYER_H -#define ZLMEDIAKIT_WEBRTCPLAYER_H - -#include "WebRtcTransport.h" -#include "Rtsp/RtspMediaSource.h" - -namespace mediakit { - -class WebRtcPlayer : public WebRtcTransportImp { -public: - using Ptr = std::shared_ptr; - static Ptr create(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info); - MediaInfo getMediaInfo() { return _media_info; } - -protected: - ///////WebRtcTransportImp override/////// - void onStartWebRTC() override; - void onDestory() override; - void onRtcConfigure(RtcConfigure &configure) const override; - -private: - WebRtcPlayer(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info); - - void sendConfigFrames(uint32_t before_seq, uint32_t sample_rate, uint32_t timestamp, uint64_t ntp_timestamp); - -private: - // 媒体相关元数据 [AUTO-TRANSLATED:f4cf8045] - // Media related metadata - MediaInfo _media_info; - // 播放的rtsp源 [AUTO-TRANSLATED:9963eed1] - // Playing rtsp source - std::weak_ptr _play_src; - - // rtp 直接转发情况下通常会缺少 sps/pps, 在转发 rtp 前, 先发送一次相关帧信息, 部分情况下是可以播放的 [AUTO-TRANSLATED:65fdf16a] - // In the case of direct RTP forwarding, sps/pps is usually missing. Before forwarding RTP, send the relevant frame information once. In some cases, it can be played. - bool _send_config_frames_once { false }; - - // 播放rtsp源的reader对象 [AUTO-TRANSLATED:7b305055] - // Reader object for playing rtsp source - RtspMediaSource::RingType::RingReader::Ptr _reader; -}; - -}// namespace mediakit -#endif // ZLMEDIAKIT_WEBRTCPLAYER_H +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTCPLAYER_H +#define ZLMEDIAKIT_WEBRTCPLAYER_H + +#include "WebRtcTransport.h" +#include "Rtsp/RtspMediaSource.h" + +namespace mediakit { +/** + * @brief H.264 B 帧过滤器 + * 用于从 H.264 RTP 流中移除 B 帧 + */ +class H264BFrameFilter { +public: + /** + * ISO_IEC_14496-10-AVC-2012 + * Table 7-6 – Name association to slice_type + */ + enum H264SliceType { + H264SliceTypeP = 0, + H264SliceTypeB = 1, + H264SliceTypeI = 2, + H264SliceTypeSP = 3, + H264SliceTypeSI = 4, + H264SliceTypeP1 = 5, + H264SliceTypeB1 = 6, + H264SliceTypeI1 = 7, + H264SliceTypeSP1 = 8, + H264SliceTypeSI1 = 9, + }; + + enum H264NALUType { + NAL_NIDR = 1, + NAL_PARTITION_A = 2, + NAL_PARTITION_B = 3, + NAL_PARTITION_C = 4, + NAL_IDR = 5, + }; + + H264BFrameFilter(); + + ~H264BFrameFilter() = default; + + /** + * @brief 处理单个 RTP 包,移除 B 帧 + * @param packet 输入的 RTP 包 + * @return 如果不是 B 帧则返回原包,否则返回 nullptr + */ + RtpPacket::Ptr processPacket(const RtpPacket::Ptr &packet); + +private: + /** + * @brief 判断 RTP 包是否包含 H.264 的 B 帧 + * @param packet RTP 包 + * @return 如果是 B 帧返回 true,否则返回 false + */ + bool isH264BFrame(const RtpPacket::Ptr &packet) const; + + /** + * @brief 根据 NAL 类型和数据判断是否是 B 帧 + * @param nal_type NAL 单元类型 + * @param data NAL 单元数据(不含 NAL 头) + * @param size 数据大小 + * @return 如果是 B 帧返回 true,否则返回 false + */ + bool isBFrameByNalType(uint8_t nal_type, const uint8_t *data, size_t size) const; + + /** + * @brief 解析指数哥伦布编码 + * @param data 数据缓冲区 + * @param size 缓冲区大小 + * @param bits_offset 位偏移量 + * @return 解析出的数值 + */ + int decodeExpGolomb(const uint8_t *data, size_t size, size_t &bitPos) const; + + /** + * @brief 从比特流中读取位 + * @param data 数据缓冲区 + * @param size 缓冲区大小 + * @return 读取的位值(0 或 1) + */ + int getBit(const uint8_t *data, size_t size) const; + + /** + * @brief 提取切片类型值 + * @param data 数据缓冲区 + * @param size 缓冲区大小 + * @return 切片类型值 + */ + uint8_t extractSliceType(const uint8_t *data, size_t size) const; + + /** + * @brief 处理FU-A分片 + * @param payload 数据缓冲区 + * @param payload_size 缓冲区大小 + * @return 如果是 B 帧返回 true,否则返回 false + */ + bool handleFua(const uint8_t *payload, size_t payload_size) const; + + /** + * @brief 处理 STAP-A 组合包 + * @param payload 数据缓冲区 + * @param payload_size 缓冲区大小 + * @return 如果是 B 帧返回 true,否则返回 false + */ + bool handleStapA(const uint8_t *payload, size_t payload_size) const; + + +private: + uint16_t _last_seq; // 维护输出流的序列号 + uint32_t _last_stamp; // 维护输出流的时间戳 + bool _first_packet; // 是否是第一个包的标记 +}; + +class WebRtcPlayer : public WebRtcTransportImp { +public: + using Ptr = std::shared_ptr; + static Ptr create(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info, + WebRtcTransport::Role role, WebRtcTransport::SignalingProtocols signaling_protocols); + MediaInfo getMediaInfo() { return _media_info; } + +protected: + ///////WebRtcTransportImp override/////// + void onStartWebRTC() override; + void onDestory() override; + void onRtcConfigure(RtcConfigure &configure) const override; + +private: + WebRtcPlayer(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info); + + void sendConfigFrames(uint32_t before_seq, uint32_t sample_rate, uint32_t timestamp, uint64_t ntp_timestamp); + +private: + // 媒体相关元数据 [AUTO-TRANSLATED:f4cf8045] + // Media related metadata + MediaInfo _media_info; + // 播放的rtsp源 [AUTO-TRANSLATED:9963eed1] + // Playing rtsp source + std::weak_ptr _play_src; + + // rtp 直接转发情况下通常会缺少 sps/pps, 在转发 rtp 前, 先发送一次相关帧信息, 部分情况下是可以播放的 [AUTO-TRANSLATED:65fdf16a] + // In the case of direct RTP forwarding, sps/pps is usually missing. Before forwarding RTP, send the relevant frame information once. In some cases, it can be played. + bool _send_config_frames_once { false }; + + // 播放rtsp源的reader对象 [AUTO-TRANSLATED:7b305055] + // Reader object for playing rtsp source + RtspMediaSource::RingType::RingReader::Ptr _reader; + + bool _is_h264 { false }; + bool _bfliter_flag { false }; + std::shared_ptr _bfilter; +}; + +}// namespace mediakit +#endif // ZLMEDIAKIT_WEBRTCPLAYER_H diff --git a/webrtc/WebRtcProxyPlayer.cpp b/webrtc/WebRtcProxyPlayer.cpp new file mode 100755 index 00000000..4b403d54 --- /dev/null +++ b/webrtc/WebRtcProxyPlayer.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcProxyPlayer.h" +#include "WebRtcProxyPlayerImp.h" +#include "WebRtcPusher.h" +#include "Common/config.h" +#include "Http/HlsPlayer.h" +#include "Rtsp/RtspMediaSourceImp.h" + +using namespace toolkit; +using namespace std; + +namespace mediakit { + +WebRtcProxyPlayer::WebRtcProxyPlayer(const EventPoller::Ptr &poller) + : WebRtcClient(poller) { + DebugL; +} + +WebRtcProxyPlayer::~WebRtcProxyPlayer(void) { + DebugL; +} + +void WebRtcProxyPlayer::play(const string &strUrl) { + DebugL; + try { + _url.parse(strUrl, isPlayer()); + } catch (std::exception &ex) { + onResult(SockException(Err_other, StrPrinter << "illegal webrtc url:" << ex.what())); + return; + } + + startConnect(); +} + +void WebRtcProxyPlayer::teardown() { + DebugL; + doBye(); +} + +void WebRtcProxyPlayer::pause(bool bPause) { + DebugL; +} + +void WebRtcProxyPlayer::speed(float speed) { + DebugL; +} + +float WebRtcProxyPlayer::getTimeOutSec() { + auto timeoutMS = (*this)[Client::kTimeoutMS].as(); + return (float)timeoutMS / (float)1000; +} + +void WebRtcProxyPlayer::onNegotiateFinish() { + DebugL; + onResult(SockException(Err_success, "webrtc play success")); + WebRtcClient::onNegotiateFinish(); +} + +/////////////////////////////////////////////////// +// WebRtcProxyPlayerImp + +void WebRtcProxyPlayerImp::startConnect() { + DebugL; + MediaInfo info(_url._full_url); + ProtocolOption option; + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + _transport = WebRtcPlayerClient::create(getPoller(), WebRtcTransport::Role::CLIENT, _url._signaling_protocols); + _transport->setOnShutdown([weak_self](const SockException &ex) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->onResult(ex); + }); + WebRtcClient::startConnect(); +} + +void WebRtcProxyPlayerImp::onResult(const SockException &ex) { + if (!ex) { + // 播放成功 + _benchmark_mode = (*this)[Client::kBenchmarkMode].as(); + + WebRtcPlayerClient::Ptr transport = std::dynamic_pointer_cast(_transport); + auto media_src = dynamic_pointer_cast(_media_src); + transport->setMediaSource(media_src); + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + if (!ex) { + transport->setOnStartWebRTC([weak_self, ex]() { + if (auto strong_self = weak_self.lock()) { + strong_self->onPlayResult(ex); + } + }); + } + } else { + WarnL << ex.getErrCode() << " " << ex.what(); + if (ex.getErrCode() == Err_shutdown) { + // 主动shutdown的,不触发回调 + return; + } + + if (!_is_negotiate_finished) { + onPlayResult(ex); + } else { + onShutdown(ex); + } + } +} + +std::vector WebRtcProxyPlayerImp::getTracks(bool ready /*= true*/) const { + auto transport = static_pointer_cast(_transport); + return transport ? transport->getTracks(ready) : Super::getTracks(ready); +} + +void WebRtcProxyPlayerImp::addTrackCompleted() { +} + +} /* namespace mediakit */ diff --git a/webrtc/WebRtcProxyPlayer.h b/webrtc/WebRtcProxyPlayer.h new file mode 100755 index 00000000..92f59424 --- /dev/null +++ b/webrtc/WebRtcProxyPlayer.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTC_PROXY_PLAYER_H +#define ZLMEDIAKIT_WEBRTC_PROXY_PLAYER_H + +#include "Network/Socket.h" +#include "Player/PlayerBase.h" +#include "Poller/Timer.h" +#include "Util/TimeTicker.h" +#include "WebRtcClient.h" +#include +#include + +namespace mediakit { + +// 实现了webrtc代理拉流功能 +class WebRtcProxyPlayer + : public PlayerBase , public WebRtcClient { +public: + using Ptr = std::shared_ptr; + + WebRtcProxyPlayer(const toolkit::EventPoller::Ptr &poller); + ~WebRtcProxyPlayer() override; + + //// PlayerBase override//// + void play(const std::string &strUrl) override; + void teardown() override; + void pause(bool pause) override; + void speed(float speed) override; + + std::shared_ptr getSockInfo() const override { + return getWebRtcTransport() ? getWebRtcTransport()->getSession() : nullptr; + } + size_t getRecvSpeed() override { + return getWebRtcTransport() ? getWebRtcTransport()->getRecvSpeed() : 0; + } + size_t getRecvTotalBytes() override { + return getWebRtcTransport() ? getWebRtcTransport()->getRecvTotalBytes() : 0; + } + +protected: + + //// WebRtcClient override//// + bool isPlayer() override {return true;} + float getTimeOutSec() override; + void onNegotiateFinish() override; + +protected: + //是否为性能测试模式 + bool _benchmark_mode = false; + + //超时功能实现 + toolkit::Ticker _recv_ticker; + std::shared_ptr _check_timer; +}; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_WEBRTC_PROXY_PLAYER_H */ diff --git a/webrtc/WebRtcProxyPlayerImp.h b/webrtc/WebRtcProxyPlayerImp.h new file mode 100755 index 00000000..8b536202 --- /dev/null +++ b/webrtc/WebRtcProxyPlayerImp.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTC_PROXY_PLAYER_IMP_H +#define ZLMEDIAKIT_WEBRTC_PROXY_PLAYER_IMP_H + +#include "WebRtcProxyPlayer.h" + +namespace mediakit { + +class WebRtcProxyPlayerImp + : public PlayerImp + , private TrackListener { +public: + using Ptr = std::shared_ptr; + using Super = PlayerImp; + + WebRtcProxyPlayerImp(const toolkit::EventPoller::Ptr &poller) : Super(poller) {} + ~WebRtcProxyPlayerImp() override { DebugL; } + +private: + + //// WebRtcProxyPlayer override//// + void startConnect() override; + + //// PlayerBase override//// + void onResult(const toolkit::SockException &ex) override; + std::vector getTracks(bool ready = true) const override; + + //// TrackListener override//// + bool addTrack(const Track::Ptr &track) override { return true; } + void addTrackCompleted() override; +}; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_WEBRTC_PROXY_PLAYER_IMP_H */ diff --git a/webrtc/WebRtcProxyPusher.cpp b/webrtc/WebRtcProxyPusher.cpp new file mode 100755 index 00000000..e67b1923 --- /dev/null +++ b/webrtc/WebRtcProxyPusher.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcProxyPusher.h" +#include "Common/config.h" +#include "Http/HlsPlayer.h" +#include "Rtsp/RtspMediaSourceImp.h" +#include "WebRtcPlayer.h" + +using namespace toolkit; +using namespace std; + +namespace mediakit { + +WebRtcProxyPusher::WebRtcProxyPusher(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src) + : WebRtcClient(poller) { + _push_src = src; + DebugL; +} + +WebRtcProxyPusher::~WebRtcProxyPusher(void) { + teardown(); + DebugL; +} + +void WebRtcProxyPusher::publish(const string &strUrl) { + DebugL; + try { + _url.parse(strUrl, isPlayer()); + } catch (std::exception &ex) { + onResult(SockException(Err_other, StrPrinter << "illegal webrtc url:" << ex.what())); + return; + } + + startConnect(); +} + +void WebRtcProxyPusher::teardown() { + DebugL; + _transport = nullptr; +} + +void WebRtcProxyPusher::onResult(const SockException &ex) { + DebugL << ex; + if (!ex) { + onPublishResult(ex); + } else { + if (!_is_negotiate_finished) { + onPublishResult(ex); + } else { + onShutdown(ex); + } + } +} + +float WebRtcProxyPusher::getTimeOutSec() { + auto timeoutMS = (*this)[Client::kTimeoutMS].as(); + return (float)timeoutMS / (float)1000; +} + +void WebRtcProxyPusher::startConnect() { + DebugL; + MediaInfo info(_url._full_url); + info.schema = "rtc"; + auto src = _push_src.lock(); + if (!src) { + onResult(SockException(Err_other, "media source released")); + return; + } + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + _transport = WebRtcPlayer::create(getPoller(), src, info, WebRtcTransport::Role::CLIENT, _url._signaling_protocols); + _transport->setOnShutdown([weak_self](const SockException &ex) { + if (auto strong_self = weak_self.lock()) { + strong_self->onResult(ex); + } + }); + _transport->setOnStartWebRTC([weak_self]() { + if (auto strong_self = weak_self.lock()) { + strong_self->onResult(SockException()); + } + }); + WebRtcClient::startConnect(); +} + +} /* namespace mediakit */ diff --git a/webrtc/WebRtcProxyPusher.h b/webrtc/WebRtcProxyPusher.h new file mode 100755 index 00000000..7ec84309 --- /dev/null +++ b/webrtc/WebRtcProxyPusher.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTC_PROXY_PUSHER_H +#define ZLMEDIAKIT_WEBRTC_PROXY_PUSHER_H + +#include "Network/Socket.h" +#include "Pusher/PusherBase.h" +#include "Poller/Timer.h" +#include "Util/TimeTicker.h" +#include "WebRtcClient.h" +#include +#include + +namespace mediakit { + +// 实现了webrtc代理拉流功能 +class WebRtcProxyPusher + : public PusherBase , public WebRtcClient { +public: + using Ptr = std::shared_ptr; + + WebRtcProxyPusher(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src); + ~WebRtcProxyPusher() override; + + //// PusherBase override//// + void publish(const std::string &url) override; + void teardown() override; + + size_t getSendSpeed() override { return getWebRtcTransport() ? getWebRtcTransport()->getSendSpeed() : 0; } + size_t getSendTotalBytes() override { return getWebRtcTransport() ? getWebRtcTransport()->getSendTotalBytes() : 0; } + +protected: + //// WebRtcClient override//// + void startConnect() override; + bool isPlayer() override { return false; } + void onResult(const toolkit::SockException &ex) override; + float getTimeOutSec() override; + +protected: + std::weak_ptr _push_src; +}; + +using WebRtcProxyPusherImp = PusherImp; + +} /* namespace mediakit */ +#endif /* ZLMEDIAKIT_WEBRTC_PROXY_PUSHER_H */ diff --git a/webrtc/WebRtcPusher.cpp b/webrtc/WebRtcPusher.cpp index 2d47440c..85a0da76 100644 --- a/webrtc/WebRtcPusher.cpp +++ b/webrtc/WebRtcPusher.cpp @@ -13,6 +13,7 @@ #include "RtcMediaSource.h" using namespace std; +using namespace toolkit; namespace mediakit { @@ -20,13 +21,18 @@ WebRtcPusher::Ptr WebRtcPusher::create(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const std::shared_ptr &ownership, const MediaInfo &info, - const ProtocolOption &option) { - WebRtcPusher::Ptr ret(new WebRtcPusher(poller, src, ownership, info, option), [](WebRtcPusher *ptr) { + const ProtocolOption &option, + WebRtcTransport::Role role, + WebRtcTransport::SignalingProtocols signaling_protocols) { + WebRtcPusher::Ptr pusher(new WebRtcPusher(poller, src, ownership, info, option), [](WebRtcPusher *ptr) { ptr->onDestory(); delete ptr; }); - ret->onCreate(); - return ret; + + pusher->setRole(role); + pusher->setSignalingProtocols(signaling_protocols); + pusher->onCreate(); + return pusher; } WebRtcPusher::WebRtcPusher(const EventPoller::Ptr &poller, @@ -42,19 +48,10 @@ WebRtcPusher::WebRtcPusher(const EventPoller::Ptr &poller, } bool WebRtcPusher::close(MediaSource &sender) { - // 此回调在其他线程触发 [AUTO-TRANSLATED:c98e7686] - // This callback is triggered in another thread - string err = StrPrinter << "close media: " << sender.getUrl(); - weak_ptr weak_self = static_pointer_cast(shared_from_this()); - getPoller()->async([weak_self, err]() { - auto strong_self = weak_self.lock(); - if (strong_self) { - strong_self->onShutdown(SockException(Err_shutdown, err)); - // 主动关闭推流,那么不延时注销 [AUTO-TRANSLATED:ee7cc580] - // Actively close the stream, then do not delay the logout - strong_self->_push_src = nullptr; - } - }); + onShutdown(SockException(Err_shutdown, "close media: " + sender.getUrl())); + // 主动关闭推流,那么不延时注销 [AUTO-TRANSLATED:ee7cc580] + // Actively close the stream, then do not delay the logout + _push_src = nullptr; return true; } @@ -117,7 +114,7 @@ void WebRtcPusher::onRecvRtp(MediaTrack &track, const string &rid, RtpPacket::Pt void WebRtcPusher::onStartWebRTC() { WebRtcTransportImp::onStartWebRTC(); _simulcast = _answer_sdp->supportSimulcast(); - if (canRecvRtp()) { + if (canRecvRtp() && _push_src) { _push_src->setSdp(_answer_sdp->toRtspSdp()); } } @@ -175,4 +172,61 @@ void WebRtcPusher::onShutdown(const SockException &ex) { WebRtcTransportImp::onShutdown(ex); } -}// namespace mediakit \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////////////// + +WebRtcPlayerClient::Ptr WebRtcPlayerClient::create(const EventPoller::Ptr &poller, WebRtcTransport::Role role, + WebRtcTransport::SignalingProtocols signaling_protocols) { + WebRtcPlayerClient::Ptr pusher(new WebRtcPlayerClient(poller), [](WebRtcPlayerClient *ptr) { + ptr->onDestory(); + delete ptr; + }); + + pusher->setRole(role); + pusher->setSignalingProtocols(signaling_protocols); + pusher->onCreate(); + return pusher; +} + +WebRtcPlayerClient::WebRtcPlayerClient(const EventPoller::Ptr &poller) + : WebRtcTransportImp(poller) { + _demuxer = std::make_shared(); +} + +void WebRtcPlayerClient::onRecvRtp(MediaTrack &track, const string &rid, RtpPacket::Ptr rtp) { + auto key_pos = _demuxer->inputRtp(rtp); + if (_push_src) { + _push_src->onWrite(rtp, key_pos); + } +} + +void WebRtcPlayerClient::onStartWebRTC() { + WebRtcTransportImp::onStartWebRTC(); + CHECK(!_answer_sdp->supportSimulcast()); + auto sdp = _answer_sdp->toRtspSdp(); + if (canRecvRtp()) { + if (_push_src) { + _push_src->setSdp(sdp); + } + _demuxer->loadSdp(sdp); + } +} + +void WebRtcPlayerClient::onRtcConfigure(RtcConfigure &configure) const { + WebRtcTransportImp::onRtcConfigure(configure); + // 这只是推流 [AUTO-TRANSLATED:f877bf98] + // This is just pushing the stream + configure.audio.direction = configure.video.direction = RtpDirection::recvonly; +} + +vector WebRtcPlayerClient::getTracks(bool ready) const { + return _demuxer->getTracks(ready); +} + +void WebRtcPlayerClient::setMediaSource(RtspMediaSource::Ptr src) { + _push_src = std::move(src); + if (_push_src && canRecvRtp()) { + _push_src->setSdp(_answer_sdp->toRtspSdp()); + } +} + +}// namespace mediakit diff --git a/webrtc/WebRtcPusher.h b/webrtc/WebRtcPusher.h index 3cf3f94b..93aacbca 100644 --- a/webrtc/WebRtcPusher.h +++ b/webrtc/WebRtcPusher.h @@ -1,87 +1,113 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef ZLMEDIAKIT_WEBRTCPUSHER_H -#define ZLMEDIAKIT_WEBRTCPUSHER_H - -#include "WebRtcTransport.h" -#include "Rtsp/RtspMediaSource.h" - -namespace mediakit { - -class WebRtcPusher : public WebRtcTransportImp, public MediaSourceEvent { -public: - using Ptr = std::shared_ptr; - static Ptr create(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, - const std::shared_ptr &ownership, const MediaInfo &info, const ProtocolOption &option); - -protected: - ///////WebRtcTransportImp override/////// - void onStartWebRTC() override; - void onDestory() override; - void onRtcConfigure(RtcConfigure &configure) const override; - void onRecvRtp(MediaTrack &track, const std::string &rid, RtpPacket::Ptr rtp) override; - void onShutdown(const SockException &ex) override; - void onRtcpBye() override; - // // dtls相关的回调 //// [AUTO-TRANSLATED:31a1f32c] - // // dtls related callbacks //// - void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override; - -protected: - ///////MediaSourceEvent override/////// - // 关闭 [AUTO-TRANSLATED:92392f02] - // Close - bool close(MediaSource &sender) override; - // 播放总人数 [AUTO-TRANSLATED:c42a3161] - // Total number of players - int totalReaderCount(MediaSource &sender) override; - // 获取媒体源类型 [AUTO-TRANSLATED:34290a69] - // Get media source type - MediaOriginType getOriginType(MediaSource &sender) const override; - // 获取媒体源url或者文件路径 [AUTO-TRANSLATED:fa34d795] - // Get media source url or file path - std::string getOriginUrl(MediaSource &sender) const override; - // 获取媒体源客户端相关信息 [AUTO-TRANSLATED:037ef910] - // Get media source client related information - std::shared_ptr getOriginSock(MediaSource &sender) const override; - // 由于支持断连续推,存在OwnerPoller变更的可能 [AUTO-TRANSLATED:1c863b40] - // Due to support for discontinuous pushing, there is a possibility of OwnerPoller changes - toolkit::EventPoller::Ptr getOwnerPoller(MediaSource &sender) override; - // 获取丢包率 [AUTO-TRANSLATED:ec61b378] - // Get packet loss rate - float getLossRate(MediaSource &sender,TrackType type) override; - -private: - WebRtcPusher(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, - const std::shared_ptr &ownership, const MediaInfo &info, const ProtocolOption &option); - -private: - bool _simulcast = false; - // 断连续推延时 [AUTO-TRANSLATED:13ad578a] - // Discontinuous pushing delay - uint32_t _continue_push_ms = 0; - // 媒体相关元数据 [AUTO-TRANSLATED:f4cf8045] - // Media related metadata - MediaInfo _media_info; - // 推流的rtsp源 [AUTO-TRANSLATED:4f976bca] - // Rtsp source of the stream - RtspMediaSource::Ptr _push_src; - // 推流所有权 [AUTO-TRANSLATED:d0ddf5c7] - // Stream ownership - std::shared_ptr _push_src_ownership; - // 推流的rtsp源,支持simulcast [AUTO-TRANSLATED:44be9120] - // Rtsp source of the stream, supports simulcast - std::recursive_mutex _mtx; - std::unordered_map _push_src_sim; - std::unordered_map > _push_src_sim_ownership; -}; - -}// namespace mediakit -#endif //ZLMEDIAKIT_WEBRTCPUSHER_H +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTCPUSHER_H +#define ZLMEDIAKIT_WEBRTCPUSHER_H + +#include "WebRtcTransport.h" +#include "Rtsp/RtspDemuxer.h" +#include "Rtsp/RtspMediaSource.h" + +namespace mediakit { + +class WebRtcPusher : public WebRtcTransportImp, public MediaSourceEvent { +public: + using Ptr = std::shared_ptr; + static Ptr create(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, + const std::shared_ptr &ownership, const MediaInfo &info, const ProtocolOption &option, + WebRtcTransport::Role role, WebRtcTransport::SignalingProtocols signaling_protocols); + +protected: + ///////WebRtcTransportImp override/////// + void onStartWebRTC() override; + void onDestory() override; + void onRtcConfigure(RtcConfigure &configure) const override; + void onRecvRtp(MediaTrack &track, const std::string &rid, RtpPacket::Ptr rtp) override; + void onShutdown(const toolkit::SockException &ex) override; + void onRtcpBye() override; + // // dtls相关的回调 //// [AUTO-TRANSLATED:31a1f32c] + // // dtls related callbacks //// + void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override; + +protected: + ///////MediaSourceEvent override/////// + // 关闭 [AUTO-TRANSLATED:92392f02] + // Close + bool close(MediaSource &sender) override; + // 播放总人数 [AUTO-TRANSLATED:c42a3161] + // Total number of players + int totalReaderCount(MediaSource &sender) override; + // 获取媒体源类型 [AUTO-TRANSLATED:34290a69] + // Get media source type + MediaOriginType getOriginType(MediaSource &sender) const override; + // 获取媒体源url或者文件路径 [AUTO-TRANSLATED:fa34d795] + // Get media source url or file path + std::string getOriginUrl(MediaSource &sender) const override; + // 获取媒体源客户端相关信息 [AUTO-TRANSLATED:037ef910] + // Get media source client related information + std::shared_ptr getOriginSock(MediaSource &sender) const override; + // 由于支持断连续推,存在OwnerPoller变更的可能 [AUTO-TRANSLATED:1c863b40] + // Due to support for discontinuous pushing, there is a possibility of OwnerPoller changes + toolkit::EventPoller::Ptr getOwnerPoller(MediaSource &sender) override; + // 获取丢包率 [AUTO-TRANSLATED:ec61b378] + // Get packet loss rate + float getLossRate(MediaSource &sender,TrackType type) override; + +private: + WebRtcPusher(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, + const std::shared_ptr &ownership, const MediaInfo &info, const ProtocolOption &option); + +private: + bool _simulcast = false; + // 断连续推延时 [AUTO-TRANSLATED:13ad578a] + // Discontinuous pushing delay + uint32_t _continue_push_ms = 0; + // 媒体相关元数据 [AUTO-TRANSLATED:f4cf8045] + // Media related metadata + MediaInfo _media_info; + // 推流的rtsp源 [AUTO-TRANSLATED:4f976bca] + // Rtsp source of the stream + RtspMediaSource::Ptr _push_src; + // 推流所有权 [AUTO-TRANSLATED:d0ddf5c7] + // Stream ownership + std::shared_ptr _push_src_ownership; + // 推流的rtsp源,支持simulcast [AUTO-TRANSLATED:44be9120] + // Rtsp source of the stream, supports simulcast + std::recursive_mutex _mtx; + std::unordered_map _push_src_sim; + std::unordered_map > _push_src_sim_ownership; +}; + +class WebRtcPlayerClient : public WebRtcTransportImp { +public: + using Ptr = std::shared_ptr; + static Ptr create(const toolkit::EventPoller::Ptr &poller, WebRtcTransport::Role role, WebRtcTransport::SignalingProtocols signaling_protocols); + + void setMediaSource(RtspMediaSource::Ptr src); + std::vector getTracks(bool ready) const; + +protected: + ///////WebRtcTransportImp override/////// + void onStartWebRTC() override; + void onRtcConfigure(RtcConfigure &configure) const override; + void onRecvRtp(MediaTrack &track, const std::string &rid, RtpPacket::Ptr rtp) override; + +private: + WebRtcPlayerClient(const toolkit::EventPoller::Ptr &poller); + +private: + RtspDemuxer::Ptr _demuxer; + // 推流的rtsp源 [AUTO-TRANSLATED:4f976bca] + // Rtsp source of the stream + RtspMediaSource::Ptr _push_src; +}; + +}// namespace mediakit +#endif //ZLMEDIAKIT_WEBRTCPUSHER_H diff --git a/webrtc/WebRtcSession.cpp b/webrtc/WebRtcSession.cpp index d5fb06e9..bbbedd92 100644 --- a/webrtc/WebRtcSession.cpp +++ b/webrtc/WebRtcSession.cpp @@ -1,166 +1,164 @@ -/* - * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. - * - * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). - * - * Use of this source code is governed by MIT-like license that can be found in the - * LICENSE file in the root of the source tree. All contributing project authors - * may be found in the AUTHORS file in the root of the source tree. - */ - -#include "WebRtcSession.h" -#include "Util/util.h" -#include "Network/TcpServer.h" -#include "Common/config.h" -#include "IceServer.hpp" -#include "WebRtcTransport.h" - -using namespace std; - -namespace mediakit { - -static string getUserName(const char *buf, size_t len) { - if (!RTC::StunPacket::IsStun((const uint8_t *) buf, len)) { - return ""; - } - std::unique_ptr packet(RTC::StunPacket::Parse((const uint8_t *) buf, len)); - if (!packet) { - return ""; - } - if (packet->GetClass() != RTC::StunPacket::Class::REQUEST || - packet->GetMethod() != RTC::StunPacket::Method::BINDING) { - return ""; - } - // 收到binding request请求 [AUTO-TRANSLATED:eff4d773] - // Received binding request - auto vec = split(packet->GetUsername(), ":"); - return vec[0]; -} - -EventPoller::Ptr WebRtcSession::queryPoller(const Buffer::Ptr &buffer) { - auto user_name = getUserName(buffer->data(), buffer->size()); - if (user_name.empty()) { - return nullptr; - } - auto ret = WebRtcTransportManager::Instance().getItem(user_name); - return ret ? ret->getPoller() : nullptr; -} - -//////////////////////////////////////////////////////////////////////////////// - -WebRtcSession::WebRtcSession(const Socket::Ptr &sock) : Session(sock) { - _over_tcp = sock->sockType() == SockNum::Sock_TCP; -} - -void WebRtcSession::attachServer(const Server &server) { - _server = std::static_pointer_cast(const_cast(server).shared_from_this()); -} - -void WebRtcSession::onRecv_l(const char *data, size_t len) { - if (_find_transport) { - // 只允许寻找一次transport [AUTO-TRANSLATED:446fae53] - // Only allow searching for transport once - _find_transport = false; - auto user_name = getUserName(data, len); - auto transport = WebRtcTransportManager::Instance().getItem(user_name); - CHECK(transport); - - // WebRtcTransport在其他poller线程上,需要切换poller线程并重新创建WebRtcSession对象 [AUTO-TRANSLATED:7e5534cf] - // WebRtcTransport is on another poller thread, need to switch poller thread and recreate WebRtcSession object - if (!transport->getPoller()->isCurrentThread()) { - auto sock = Socket::createSocket(transport->getPoller(), false); - // 1、克隆socket(fd不变),切换poller线程到WebRtcTransport所在线程 [AUTO-TRANSLATED:f930bfab] - // 1. Clone socket (fd remains unchanged), switch poller thread to the thread where WebRtcTransport is located - sock->cloneSocket(*(getSock())); - auto server = _server; - std::string str(data, len); - sock->getPoller()->async([sock, server, str](){ - auto strong_server = server.lock(); - if (strong_server) { - auto session = static_pointer_cast(strong_server->createSession(sock)); - // 2、创建新的WebRtcSession对象(绑定到WebRtcTransport所在线程),重新处理一遍ice binding request命令 [AUTO-TRANSLATED:c75203bb] - // 2. Create a new WebRtcSession object (bound to the thread where WebRtcTransport is located), reprocess the ice binding request command - session->onRecv_l(str.data(), str.size()); - } - }); - // 3、销毁原先的socket和WebRtcSession(原先的对象跟WebRtcTransport不在同一条线程) [AUTO-TRANSLATED:a6d6d63f] - // 3. Destroy the original socket and WebRtcSession (the original object is not on the same thread as WebRtcTransport) - throw std::runtime_error("webrtc over tcp change poller: " + getPoller()->getThreadName() + " -> " + sock->getPoller()->getThreadName()); - } - _transport = std::move(transport); - InfoP(this); - } - _ticker.resetTime(); - CHECK(_transport); - _transport->inputSockData((char *)data, len, this); -} - -void WebRtcSession::onRecv(const Buffer::Ptr &buffer) { - if (_over_tcp) { - input(buffer->data(), buffer->size()); - } else { - onRecv_l(buffer->data(), buffer->size()); - } -} - -void WebRtcSession::onError(const SockException &err) { - // udp链接超时,但是rtc链接不一定超时,因为可能存在链接迁移的情况 [AUTO-TRANSLATED:aaa9672f] - // UDP connection timeout, but RTC connection may not timeout, because there may be connection migration - // 在udp链接迁移时,新的WebRtcSession对象将接管WebRtcTransport对象的生命周期 [AUTO-TRANSLATED:7e7d19df] - // When UDP connection migrates, the new WebRtcSession object will take over the life cycle of the WebRtcTransport object - // 本WebRtcSession对象将在超时后自动销毁 [AUTO-TRANSLATED:bc903a06] - // This WebRtcSession object will be automatically destroyed after timeout - WarnP(this) << err; - - if (!_transport) { - return; - } - auto self = static_pointer_cast(shared_from_this()); - auto transport = std::move(_transport); - getPoller()->async([transport, self]() mutable { - // 延时减引用,防止使用transport对象时,销毁对象 [AUTO-TRANSLATED:09dd6609] - // Delay decrementing the reference count to prevent the object from being destroyed when using the transport object - transport->removeTuple(self.get()); - // 确保transport在Session对象前销毁,防止WebRtcTransport::onDestory()时获取不到Session对象 [AUTO-TRANSLATED:acd8bd77] - // Ensure that the transport is destroyed before the Session object to prevent WebRtcTransport::onDestory() from not being able to get the Session object - transport = nullptr; - }, false); -} - -void WebRtcSession::onManager() { - GET_CONFIG(float, timeoutSec, Rtc::kTimeOutSec); - if (!_transport && _ticker.createdTime() > timeoutSec * 1000) { - shutdown(SockException(Err_timeout, "illegal webrtc connection")); - return; - } - if (_ticker.elapsedTime() > timeoutSec * 1000) { - shutdown(SockException(Err_timeout, "webrtc connection timeout")); - return; - } -} - -ssize_t WebRtcSession::onRecvHeader(const char *data, size_t len) { - onRecv_l(data + 2, len - 2); - return 0; -} - -const char *WebRtcSession::onSearchPacketTail(const char *data, size_t len) { - if (len < 2) { - // 数据不够 [AUTO-TRANSLATED:830a2785] - // Not enough data - return nullptr; - } - uint16_t length = (((uint8_t *)data)[0] << 8) | ((uint8_t *)data)[1]; - if (len < (size_t)(length + 2)) { - // 数据不够 [AUTO-TRANSLATED:830a2785] - // Not enough data - return nullptr; - } - // 返回rtp包末尾 [AUTO-TRANSLATED:5134cf6f] - // Return the end of the RTP packet - return data + 2 + length; -} - -}// namespace mediakit - - +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcSession.h" +#include "Util/util.h" +#include "Network/TcpServer.h" +#include "Common/config.h" +#include "IceTransport.hpp" +#include "WebRtcTransport.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +static string getUserName(const char *buf, size_t len) { + if (!RTC::StunPacket::isStun((const uint8_t *) buf, len)) { + return ""; + } + auto packet = RTC::StunPacket::parse((const uint8_t *) buf, len); + if (!packet) { + return ""; + } + + // 收到binding request请求 [AUTO-TRANSLATED:eff4d773] + // Received binding request + auto vec = split(packet->getUsername(), ":"); + return vec[0]; +} + +EventPoller::Ptr WebRtcSession::queryPoller(const Buffer::Ptr &buffer) { + auto user_name = getUserName(buffer->data(), buffer->size()); + if (user_name.empty()) { + return nullptr; + } + auto ret = WebRtcTransportManager::Instance().getItem(user_name); + return ret ? ret->getPoller() : nullptr; +} + +//////////////////////////////////////////////////////////////////////////////// + +WebRtcSession::WebRtcSession(const Socket::Ptr &sock) : Session(sock) { + _over_tcp = sock->sockType() == SockNum::Sock_TCP; +} + +void WebRtcSession::attachServer(const Server &server) { + _server = std::static_pointer_cast(const_cast(server).shared_from_this()); +} + +void WebRtcSession::onRecv_l(const char *data, size_t len) { + if (_find_transport) { + // 只允许寻找一次transport [AUTO-TRANSLATED:446fae53] + // Only allow searching for transport once + _find_transport = false; + auto user_name = getUserName(data, len); + auto transport = WebRtcTransportManager::Instance().getItem(user_name); + CHECK(transport); + + // WebRtcTransport在其他poller线程上,需要切换poller线程并重新创建WebRtcSession对象 [AUTO-TRANSLATED:7e5534cf] + // WebRtcTransport is on another poller thread, need to switch poller thread and recreate WebRtcSession object + if (!transport->getPoller()->isCurrentThread()) { + auto sock = Socket::createSocket(transport->getPoller(), false); + // 1、克隆socket(fd不变),切换poller线程到WebRtcTransport所在线程 [AUTO-TRANSLATED:f930bfab] + // 1. Clone socket (fd remains unchanged), switch poller thread to the thread where WebRtcTransport is located + auto on_complete = sock->cloneSocket(*(getSock())); + auto server = _server; + std::string str(data, len); + // on_complete在创建WebRtcSession后才析构(才开始网络事件监听) + sock->getPoller()->async([sock, server, str, on_complete](){ + auto strong_server = server.lock(); + if (strong_server) { + auto session = static_pointer_cast(strong_server->createSession(sock)); + // 2、创建新的WebRtcSession对象(绑定到WebRtcTransport所在线程),重新处理一遍ice binding request命令 [AUTO-TRANSLATED:c75203bb] + // 2. Create a new WebRtcSession object (bound to the thread where WebRtcTransport is located), reprocess the ice binding request command + session->onRecv_l(str.data(), str.size()); + } + }); + // 3、销毁原先的socket和WebRtcSession(原先的对象跟WebRtcTransport不在同一条线程) [AUTO-TRANSLATED:a6d6d63f] + // 3. Destroy the original socket and WebRtcSession (the original object is not on the same thread as WebRtcTransport) + throw std::runtime_error("webrtc over tcp change poller: " + getPoller()->getThreadName() + " -> " + sock->getPoller()->getThreadName()); + } + _transport = std::move(transport); + InfoP(this); + } + _ticker.resetTime(); + CHECK(_transport); + auto self = static_pointer_cast(shared_from_this()); + _transport->inputSockData(data, len, self); +} + +void WebRtcSession::onRecv(const Buffer::Ptr &buffer) { + if (_over_tcp) { + input(buffer->data(), buffer->size()); + } else { + onRecv_l(buffer->data(), buffer->size()); + } +} + +void WebRtcSession::onError(const SockException &err) { + // udp链接超时,但是rtc链接不一定超时,因为可能存在链接迁移的情况 [AUTO-TRANSLATED:aaa9672f] + // UDP connection timeout, but RTC connection may not timeout, because there may be connection migration + // 在udp链接迁移时,新的WebRtcSession对象将接管WebRtcTransport对象的生命周期 [AUTO-TRANSLATED:7e7d19df] + // When UDP connection migrates, the new WebRtcSession object will take over the life cycle of the WebRtcTransport object + // 本WebRtcSession对象将在超时后自动销毁 [AUTO-TRANSLATED:bc903a06] + // This WebRtcSession object will be automatically destroyed after timeout + WarnP(this) << err; + + if (!_transport) { + return; + } + auto self = static_pointer_cast(shared_from_this()); + auto transport = std::move(_transport); + getPoller()->async([transport, self]() mutable { + // 延时减引用,防止使用transport对象时,销毁对象 [AUTO-TRANSLATED:09dd6609] + // Delay decrementing the reference count to prevent the object from being destroyed when using the transport object + transport->removePair(self.get()); + // 确保transport在Session对象前销毁,防止WebRtcTransport::onDestory()时获取不到Session对象 [AUTO-TRANSLATED:acd8bd77] + // Ensure that the transport is destroyed before the Session object to prevent WebRtcTransport::onDestory() from not being able to get the Session object + transport = nullptr; + }, false); +} + +void WebRtcSession::onManager() { + GET_CONFIG(float, timeoutSec, Rtc::kTimeOutSec); + if (!_transport && _ticker.createdTime() > timeoutSec * 1000) { + shutdown(SockException(Err_timeout, "illegal webrtc connection")); + return; + } + if (_ticker.elapsedTime() > timeoutSec * 1000) { + shutdown(SockException(Err_timeout, "webrtc connection timeout")); + return; + } +} + +ssize_t WebRtcSession::onRecvHeader(const char *data, size_t len) { + onRecv_l(data + 2, len - 2); + return 0; +} + +const char *WebRtcSession::onSearchPacketTail(const char *data, size_t len) { + if (len < 2) { + // 数据不够 [AUTO-TRANSLATED:830a2785] + // Not enough data + return nullptr; + } + uint16_t length = (((uint8_t *)data)[0] << 8) | ((uint8_t *)data)[1]; + if (len < (size_t)(length + 2)) { + // 数据不够 [AUTO-TRANSLATED:830a2785] + // Not enough data + return nullptr; + } + // 返回rtp包末尾 [AUTO-TRANSLATED:5134cf6f] + // Return the end of the RTP packet + return data + 2 + length; +} + +}// namespace mediakit diff --git a/webrtc/WebRtcSession.h b/webrtc/WebRtcSession.h index 3bfd96bc..7afbb2c4 100644 --- a/webrtc/WebRtcSession.h +++ b/webrtc/WebRtcSession.h @@ -21,18 +21,18 @@ namespace toolkit { } namespace mediakit { + class WebRtcTransportImp; -using namespace toolkit; -class WebRtcSession : public Session, public HttpRequestSplitter { +class WebRtcSession : public toolkit::Session, public HttpRequestSplitter { public: - WebRtcSession(const Socket::Ptr &sock); + WebRtcSession(const toolkit::Socket::Ptr &sock); - void attachServer(const Server &server) override; - void onRecv(const Buffer::Ptr &) override; - void onError(const SockException &err) override; + void attachServer(const toolkit::Server &server) override; + void onRecv(const toolkit::Buffer::Ptr &) override; + void onError(const toolkit::SockException &err) override; void onManager() override; - static EventPoller::Ptr queryPoller(const Buffer::Ptr &buffer); + static toolkit::EventPoller::Ptr queryPoller(const toolkit::Buffer::Ptr &buffer); protected: WebRtcTransportImp::Ptr _transport; @@ -47,7 +47,7 @@ private: private: bool _over_tcp = false; bool _find_transport = true; - Ticker _ticker; + toolkit::Ticker _ticker; std::weak_ptr _server; }; diff --git a/webrtc/WebRtcSignalingMsg.cpp b/webrtc/WebRtcSignalingMsg.cpp new file mode 100644 index 00000000..230f285f --- /dev/null +++ b/webrtc/WebRtcSignalingMsg.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcSignalingMsg.h" + +namespace mediakit { +namespace Rtc { + +// WebRTC 信令消息键名和值常量定义 +const char* const CLASS_KEY = "class"; +const char* const CLASS_VALUE_REQUEST = "request"; +const char* const CLASS_VALUE_INDICATION = "indication"; // 指示类型,不需要应答 +const char* const CLASS_VALUE_ACCEPT = "accept"; // 作为CLASS_VALUE_REQUEST的应答 +const char* const CLASS_VALUE_REJECT = "reject"; // 作为CLASS_VALUE_REQUEST的应答 +const char* const METHOD_KEY = "method"; +const char* const METHOD_VALUE_REGISTER = "register"; // 注册 +const char* const METHOD_VALUE_UNREGISTER = "unregister"; // 注销 +const char* const METHOD_VALUE_CALL = "call"; // 呼叫(取流或推流) + +const char* const METHOD_VALUE_BYE = "bye"; // 挂断 +const char* const METHOD_VALUE_CANDIDATE = "candidate"; +const char* const TRANSACTION_ID_KEY = "transaction_id"; // 消息id,每条消息拥有一个唯一的id +const char* const ROOM_ID_KEY = "room_id"; +const char* const GUEST_ID_KEY = "guest_id"; // 每个独立的会话,会拥有一个唯一的guest_id +const char* const SENDER_KEY = "sender"; +const char* const TYPE_KEY = "type"; +const char* const TYPE_VALUE_PLAY = "play"; // 拉流 +const char* const TYPE_VALUE_PUSH = "push"; // 推流 +const char* const REASON_KEY = "reason"; +const char* const CALL_VHOST_KEY = "vhost"; +const char* const CALL_APP_KEY = "app"; +const char* const CALL_STREAM_KEY = "stream"; +const char* const SDP_KEY = "sdp"; + +const char* const ICE_SERVERS_KEY = "ice_servers"; +const char* const CANDIDATE_KEY = "candidate"; +const char* const URL_KEY = "url"; +const char* const UFRAG_KEY = "ufrag"; +const char* const PWD_KEY = "pwd"; + +} // namespace Rtc +} // namespace mediakit diff --git a/webrtc/WebRtcSignalingMsg.h b/webrtc/WebRtcSignalingMsg.h new file mode 100644 index 00000000..89fed39e --- /dev/null +++ b/webrtc/WebRtcSignalingMsg.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + + +#ifndef ZLMEDIAKIT_WEBRTC_SIGNALING_MSG_H +#define ZLMEDIAKIT_WEBRTC_SIGNALING_MSG_H + +#include "server/WebApi.h" + +namespace mediakit { +namespace Rtc { + +#define SIGNALING_MSG_ARGS const HttpAllArgs& allArgs + +// WebRTC 信令消息键名和值常量声明 +extern const char* const CLASS_KEY; +extern const char* const CLASS_VALUE_REQUEST; +extern const char* const CLASS_VALUE_INDICATION; // 指示类型,不需要应答 +extern const char* const CLASS_VALUE_ACCEPT; // 作为CLASS_VALUE_REQUEST的应答 +extern const char* const CLASS_VALUE_REJECT; // 作为CLASS_VALUE_REQUEST的应答 +extern const char* const METHOD_KEY; +extern const char* const METHOD_VALUE_REGISTER; // 注册 +extern const char* const METHOD_VALUE_UNREGISTER; // 注销 +extern const char* const METHOD_VALUE_CALL; // 呼叫(取流或推流) + +extern const char* const METHOD_VALUE_BYE; // 挂断 +extern const char* const METHOD_VALUE_CANDIDATE; +extern const char* const TRANSACTION_ID_KEY; // 消息id,每条消息拥有一个唯一的id +extern const char* const ROOM_ID_KEY; +extern const char* const GUEST_ID_KEY; // 每个独立的会话,会拥有一个唯一的guest_id +extern const char* const SENDER_KEY; +extern const char* const TYPE_KEY; +extern const char* const TYPE_VALUE_PLAY; // 拉流 +extern const char* const TYPE_VALUE_PUSH; // 推流 +extern const char* const REASON_KEY; +extern const char* const CALL_VHOST_KEY; +extern const char* const CALL_APP_KEY; +extern const char* const CALL_STREAM_KEY; +extern const char* const SDP_KEY; + +extern const char* const ICE_SERVERS_KEY; +extern const char* const CANDIDATE_KEY; +extern const char* const URL_KEY; +extern const char* const UFRAG_KEY; +extern const char* const PWD_KEY; + +} // namespace Rtc +} // namespace mediakit +// + +#endif //ZLMEDIAKIT_WEBRTC_SIGNALING_PEER_H diff --git a/webrtc/WebRtcSignalingPeer.cpp b/webrtc/WebRtcSignalingPeer.cpp new file mode 100644 index 00000000..f224e820 --- /dev/null +++ b/webrtc/WebRtcSignalingPeer.cpp @@ -0,0 +1,687 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcSignalingPeer.h" +#include "WebRtcSignalingMsg.h" +#include "Util/util.h" +#include "Common/config.h" +#include "json/value.h" + +using namespace std; +using namespace toolkit; +using namespace mediakit::Rtc; + +namespace mediakit { + +// 注册到的信令服务器列表 +// 不允许注册到同一个服务器地址 +static ServiceController s_room_keepers; + +static inline string getRoomKeepersKey(const string &host, uint16_t &port) { + return host + ":" + std::to_string(port); +} + +void addWebrtcRoomKeeper(const string &host, uint16_t port, const std::string& room_id, bool ssl, + const function &cb) { + DebugL; + auto key = getRoomKeepersKey(host, port); + if (s_room_keepers.find(key)) { + //已经发起注册了 + cb(SockException(Err_success), key); + return; + } + auto peer = s_room_keepers.make(key, host, port, ssl, room_id); + peer->setOnShutdown([key] (const SockException &ex) { + InfoL << "webrtc peer shutdown, key: " << key << ", " << ex.what(); + s_room_keepers.erase(key); + }); + + peer->setOnConnect([peer, cb] (const SockException &ex) { + peer->regist(cb); + }); + peer->connect(); +} + +void delWebrtcRoomKeeper(const std::string &key, const std::function &cb) { + auto peer = s_room_keepers.find(key); + if (!peer) { + return cb(SockException(Err_other, "room_key not exist")); + } + peer->unregist(cb); + s_room_keepers.erase(key); +} + +void listWebrtcRoomKeepers(const std::function &cb) { + s_room_keepers.for_each(cb); +} + +Json::Value ToJson(const WebRtcSignalingPeer::Ptr& p) { + return p->makeInfoJson(); +} + +WebRtcSignalingPeer::Ptr getWebrtcRoomKeeper(const string &host, uint16_t port) { + return s_room_keepers.find(getRoomKeepersKey(host, port)); +} + +//////////// WebRtcSignalingPeer ////////////////////////// + +WebRtcSignalingPeer::WebRtcSignalingPeer(const std::string &host, uint16_t port, bool ssl, const std::string &room_id, const EventPoller::Ptr &poller) + : WebSocketClient(poller) + , _room_id(room_id) { + TraceL; + // TODO: not support wss now + _ws_url = StrPrinter << (ssl ? "wss://" : "ws://") + host << ":" << port << "/signaling"; + _room_key = getRoomKeepersKey(host, port); +} + +WebRtcSignalingPeer::~WebRtcSignalingPeer() { + DebugL << "room_id: " << _room_id; +} + +void WebRtcSignalingPeer::connect() { + DebugL; + startWebSocket(_ws_url); +} + +void WebRtcSignalingPeer::regist(const function &cb) { + DebugL; + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, cb]() mutable { + if (auto strong_self = weak_self.lock()) { + strong_self->sendRegisterRequest(std::move(cb)); + } + }); +} + +void WebRtcSignalingPeer::unregist(const function &cb) { + DebugL; + auto trigger = [cb](const SockException &ex, std::string msg) { cb(ex); }; + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, trigger]() mutable { + if (auto strong_self = weak_self.lock()) { + strong_self->sendUnregisterRequest(std::move(trigger)); + } + }); +} + +void WebRtcSignalingPeer::checkIn(const std::string& peer_room_id, const MediaTuple &tuple, const std::string& identifier, + const std::string& offer, bool is_play, + const function &cb, float timeout_sec) { + DebugL; + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([=] () mutable { + TraceL; + if (auto strong_self = weak_self.lock()) { + auto guest_id = strong_self->_room_id + "_" + makeRandStr(16); + strong_self->_tours.emplace(peer_room_id, std::make_pair(guest_id, identifier)); + auto trigger = ([cb, peer_room_id, weak_self](const SockException &ex, const std::string &msg) { + auto strong_self = weak_self.lock(); + if (ex && strong_self) { + strong_self->_tours.erase(peer_room_id); + } + return cb(ex, msg); + }); + strong_self->sendCallRequest(peer_room_id, guest_id, tuple, offer, is_play, std::move(trigger)); + } + }); +} + +void WebRtcSignalingPeer::checkOut(const std::string& peer_room_id) { + DebugL; + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([=] () { + TraceL; + if (auto strong_self = weak_self.lock()) { + auto it = strong_self->_tours.find(peer_room_id); + if (it != strong_self->_tours.end()) { + auto &guest_id = it->second.first; + strong_self->sendByeIndication(peer_room_id, guest_id); + strong_self->_tours.erase(it); + } + } + }); +} + +void WebRtcSignalingPeer::candidate(const std::string& transport_identifier, const std::string& candidate, const std::string& ice_ufrag, const std::string& ice_pwd) { + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([=] () { + if (auto strong_self = weak_self.lock()) { + strong_self->sendCandidateIndication(transport_identifier, candidate, ice_ufrag, ice_pwd); + } + }); +} + +void WebRtcSignalingPeer::processOffer(SIGNALING_MSG_ARGS, WebRtcInterface &transport) { + try { + auto sdp = transport.getAnswerSdp((const std::string )allArgs[SDP_KEY]); + auto tuple = MediaTuple(allArgs[CALL_VHOST_KEY], allArgs[CALL_APP_KEY], allArgs[CALL_STREAM_KEY]); + answer(allArgs[GUEST_ID_KEY], tuple, transport.getIdentifier(), sdp, allArgs[TYPE_KEY] == TYPE_VALUE_PLAY, allArgs[TRANSACTION_ID_KEY]); + + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + transport.gatheringCandidate(_ice_server, [weak_self](const std::string& transport_identifier, + const std::string& candidate, const std::string& ufrag, const std::string& pwd) { + if (auto strong_self = weak_self.lock()) { + strong_self->candidate(transport_identifier, candidate, ufrag, pwd); + } + }); + } catch (std::exception &ex) { + Json::Value body; + body[METHOD_KEY] = allArgs[METHOD_KEY]; + body[ROOM_ID_KEY] = allArgs[ROOM_ID_KEY]; + body[GUEST_ID_KEY] = allArgs[GUEST_ID_KEY]; + body[CALL_VHOST_KEY] = allArgs[CALL_VHOST_KEY]; + body[CALL_APP_KEY] = allArgs[CALL_APP_KEY]; + body[CALL_STREAM_KEY] = allArgs[CALL_STREAM_KEY]; + body[TYPE_KEY] = allArgs[TYPE_KEY]; + sendRefusesResponse(body, allArgs[TRANSACTION_ID_KEY], ex.what()); + } +} + +void WebRtcSignalingPeer::answer(const std::string& guest_id, const MediaTuple &tuple, const std::string& identifier, const std::string& sdp, bool is_play, const std::string& transaction_id) { + _peer_guests.emplace(guest_id, identifier); + sendCallAccept(guest_id, tuple, sdp, is_play, transaction_id); +} + +void WebRtcSignalingPeer::setOnConnect(function cb) { + _on_connect = cb ? std::move(cb) : [](const SockException &) {}; +} + +void WebRtcSignalingPeer::onConnect(const SockException &ex) { + TraceL; + if (_on_connect) { + _on_connect(ex); + } + if (!ex) { + createResponseExpiredTimer(); + } +} + +void WebRtcSignalingPeer::setOnShutdown(function cb) { + _on_shutdown = cb ? std::move(cb) : [](const SockException &) {}; +} + +void WebRtcSignalingPeer::onShutdown(const SockException &ex) { + TraceL; + if (_on_shutdown) { + _on_shutdown(ex); + } +} + +void WebRtcSignalingPeer::onRecv(const Buffer::Ptr &buffer) { + TraceL << "recv msg:\r\n" << buffer->data(); + + Json::Value args; + Json::Reader reader; + reader.parse(buffer->data(), args); + Parser parser; + HttpAllArgs allArgs(parser, args); + + CHECK_ARGS(METHOD_KEY, TRANSACTION_ID_KEY); + + using MsgHandler = void (WebRtcSignalingPeer::*)(SIGNALING_MSG_ARGS); + static std::unordered_map, MsgHandler, ClassMethodHash> s_msg_handlers; + + static onceToken token([]() { + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_ACCEPT, METHOD_VALUE_REGISTER), &WebRtcSignalingPeer::handleRegisterAccept); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REJECT, METHOD_VALUE_REGISTER), &WebRtcSignalingPeer::handleRegisterReject); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_ACCEPT, METHOD_VALUE_UNREGISTER), &WebRtcSignalingPeer::handleUnregisterAccept); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REJECT, METHOD_VALUE_UNREGISTER), &WebRtcSignalingPeer::handleUnregisterReject); + + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REQUEST, METHOD_VALUE_CALL), &WebRtcSignalingPeer::handleCallRequest); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_ACCEPT, METHOD_VALUE_CALL), &WebRtcSignalingPeer::handleCallAccept); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REJECT, METHOD_VALUE_CALL), &WebRtcSignalingPeer::handleCallReject); + + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_INDICATION, METHOD_VALUE_CANDIDATE), &WebRtcSignalingPeer::handleCandidateIndication); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_INDICATION, METHOD_VALUE_BYE), &WebRtcSignalingPeer::handleByeIndication); + }); + + auto it = s_msg_handlers.find(std::make_pair(allArgs[CLASS_KEY], allArgs[METHOD_KEY])); + if (it == s_msg_handlers.end()) { + WarnL << "unsupported class: "<< allArgs[CLASS_KEY] << ", method: " << allArgs[METHOD_KEY] << ", ignore"; + return; + } + return (this->*(it->second))(allArgs); +} + +void WebRtcSignalingPeer::onError(const SockException &err) { + WarnL << "room_id: " << _room_id; + s_room_keepers.erase(_room_key); + // 除非对端显式的发送了注销执行,否则因为网络异常导致的会话中断,不影响已经进行通信的webrtc会话,仅作移除 +} + +bool WebRtcSignalingPeer::responseFilter(SIGNALING_MSG_ARGS, ResponseTrigger& trigger) { + if (allArgs[CLASS_KEY] != CLASS_VALUE_ACCEPT && allArgs[CLASS_KEY] != CLASS_VALUE_REJECT) { + return false; + } + + for (auto &pr : _response_list) { + auto &transaction_id = pr.first; + // mismatch transaction_id + if (transaction_id != allArgs[TRANSACTION_ID_KEY] && !transaction_id.empty()) { + continue; + } + + auto &handle = pr.second; + if (allArgs[METHOD_KEY] != handle.method) { + WarnL << "recv response method: " << allArgs[METHOD_KEY] << " mismatch request method: " << handle.method; + return false; + } + + trigger = std::move(handle.cb); + _response_list.erase(transaction_id); + return true; + } + return false; +} + +void WebRtcSignalingPeer::sendRegisterRequest(ResponseTrigger trigger) { + TraceL; + Json::Value body; + body[CLASS_KEY] = CLASS_VALUE_REQUEST; + body[METHOD_KEY] = METHOD_VALUE_REGISTER; + body[ROOM_ID_KEY] = getRoomId(); + sendRequest(body, std::move(trigger)); +} + +void WebRtcSignalingPeer::handleRegisterAccept(SIGNALING_MSG_ARGS) { + TraceL; + ResponseTrigger trigger; + if (!responseFilter(allArgs, trigger)) { + return; + } + + auto jsonArgs = allArgs.getArgs(); + auto ice_servers = jsonArgs[ICE_SERVERS_KEY]; + if (ice_servers.type() != Json::ValueType::arrayValue) { + _StrPrinter msg; + msg << "illegal \"" << ICE_SERVERS_KEY << "\" point"; + WarnL << msg; + trigger(SockException(Err_other, msg), getRoomKey()); + return; + } + + if (ice_servers.empty()) { + _StrPrinter msg; + msg << "no ice server found in \"" << ICE_SERVERS_KEY << "\" point"; + WarnL << msg; + trigger(SockException(Err_other, msg), getRoomKey()); + return; + } + + for (auto &ice_server : ice_servers) { + // only support 1 ice_server now + auto url = ice_server[URL_KEY].asString(); + _ice_server = std::make_shared(url); + _ice_server->_ufrag = ice_server[UFRAG_KEY].asString(); + _ice_server->_pwd = ice_server[PWD_KEY].asString(); + } + + trigger(SockException(Err_success), getRoomKey()); +} + +void WebRtcSignalingPeer::handleRegisterReject(SIGNALING_MSG_ARGS) { + TraceL; + ResponseTrigger trigger; + if (!responseFilter(allArgs, trigger)) { + return; + } + + auto ex = SockException(Err_other, StrPrinter << "register refuses by server, reason: " << allArgs[REASON_KEY]); + trigger(ex, getRoomKey()); + onShutdown(ex); +} + +void WebRtcSignalingPeer::sendUnregisterRequest(ResponseTrigger trigger) { + TraceL; + Json::Value body; + body[CLASS_KEY] = CLASS_VALUE_REQUEST; + body[METHOD_KEY] = METHOD_VALUE_UNREGISTER; + body[ROOM_ID_KEY] = _room_id; + sendRequest(body, std::move(trigger)); +} + +void WebRtcSignalingPeer::handleUnregisterAccept(SIGNALING_MSG_ARGS) { + ResponseTrigger trigger; + if (!responseFilter(allArgs, trigger)) { + return; + } + + trigger(SockException(Err_success), getRoomKey()); +} + +void WebRtcSignalingPeer::handleUnregisterReject(SIGNALING_MSG_ARGS) { + ResponseTrigger trigger; + if (!responseFilter(allArgs, trigger)) { + return; + } + + auto ex = SockException(Err_other, StrPrinter << "unregister refuses by server, reason: " << allArgs[REASON_KEY]); + trigger(ex, getRoomKey()); +} + +void WebRtcSignalingPeer::sendCallRequest(const std::string& peer_room_id, const std::string& guest_id, const MediaTuple &tuple, const std::string& sdp, bool is_play, ResponseTrigger trigger) { + DebugL; + Json::Value body; + body[CLASS_KEY] = CLASS_VALUE_REQUEST; + body[METHOD_KEY] = METHOD_VALUE_CALL; + body[TYPE_KEY] = is_play? TYPE_VALUE_PLAY : TYPE_VALUE_PUSH; + body[GUEST_ID_KEY] = guest_id; //our guest id + body[ROOM_ID_KEY] = peer_room_id; + body[CALL_VHOST_KEY] = tuple.vhost; + body[CALL_APP_KEY] = tuple.app; + body[CALL_STREAM_KEY] = tuple.stream; + body[SDP_KEY] = sdp; + sendRequest(body, std::move(trigger)); +} + +void WebRtcSignalingPeer::sendCallAccept(const std::string& peer_guest_id, const MediaTuple &tuple, const std::string& sdp, bool is_play, const std::string& transaction_id) { + DebugL; + Json::Value body; + body[CLASS_KEY] = CLASS_VALUE_ACCEPT; + body[METHOD_KEY] = METHOD_VALUE_CALL; + body[TRANSACTION_ID_KEY] = transaction_id; + body[TYPE_KEY] = is_play? TYPE_VALUE_PLAY : TYPE_VALUE_PUSH; + body[GUEST_ID_KEY] = peer_guest_id; + body[ROOM_ID_KEY] = _room_id; //our room id + body[CALL_VHOST_KEY] = tuple.vhost; + body[CALL_APP_KEY] = tuple.app; + body[CALL_STREAM_KEY] = tuple.stream; + body[SDP_KEY] = sdp; + sendPacket(body); +} + +void WebRtcSignalingPeer::handleCallRequest(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY, CALL_VHOST_KEY, CALL_APP_KEY, CALL_STREAM_KEY, TYPE_KEY); + + if (allArgs[ROOM_ID_KEY] != getRoomId()) { + WarnL << "target room_id: " << allArgs[ROOM_ID_KEY] << "mismatch our room_id: " << getRoomId(); + return; + } + + auto args = std::make_shared>(allArgs, allArgs[GUEST_ID_KEY]); + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + WebRtcPluginManager::Instance().negotiateSdp(*this, allArgs[TYPE_KEY], *args, [allArgs, weak_self](const WebRtcInterface &exchanger) mutable { + if (auto strong_self = weak_self.lock()) { + strong_self->processOffer(allArgs, const_cast(exchanger)); + } + }); +} + +void WebRtcSignalingPeer::handleCallAccept(SIGNALING_MSG_ARGS) { + DebugL; + ResponseTrigger trigger; + if (!responseFilter(allArgs, trigger)) { + return; + } + + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY, CALL_VHOST_KEY, CALL_APP_KEY, CALL_STREAM_KEY, TYPE_KEY); + + auto room_id = allArgs[ROOM_ID_KEY]; + auto it = _tours.find(room_id); + if (it == _tours.end()) { + WarnL << "not found room_id: " << room_id << " in tours"; + return; + } + + auto &guest_id = it->second.first; + if (allArgs[GUEST_ID_KEY] != guest_id) { + WarnL << "guest_id: " << allArgs[GUEST_ID_KEY] << "mismatch our guest_id: " << guest_id; + return; + } + + trigger(SockException(Err_success), allArgs[SDP_KEY]); +} + +void WebRtcSignalingPeer::handleCallReject(SIGNALING_MSG_ARGS) { + DebugL; + ResponseTrigger trigger; + if (!responseFilter(allArgs, trigger)) { + return; + } + + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY, CALL_VHOST_KEY, CALL_APP_KEY, CALL_STREAM_KEY, TYPE_KEY); + + auto room_id = allArgs[ROOM_ID_KEY]; + auto it = _tours.find(room_id); + if (it == _tours.end()) { + WarnL << "not found room_id: " << room_id << " in tours"; + return; + } + + auto &guest_id = it->second.first; + if (allArgs[GUEST_ID_KEY] != guest_id) { + WarnL << "guest_id: " << allArgs[GUEST_ID_KEY] << "mismatch our guest_id: " << guest_id; + return; + } + + _tours.erase(room_id); + trigger(SockException(Err_other, StrPrinter << "call refuses by server, reason: " << allArgs[REASON_KEY]), ""); +} + +void WebRtcSignalingPeer::handleCandidateIndication(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY, CANDIDATE_KEY, UFRAG_KEY, PWD_KEY); + + std::string identifier; + std::string room_id = allArgs[ROOM_ID_KEY]; + std::string guest_id = allArgs[GUEST_ID_KEY]; + //作为被叫 + if (room_id == getRoomId()) { + auto it = _peer_guests.find(guest_id); + if (it == _peer_guests.end()) { + WarnL << "not found guest_id: " << guest_id; + return; + } + identifier = it->second; + } else { + //作为主叫 + for (auto it : _tours) { + if (room_id != it.first) { + continue; + } + + auto info = it.second; + if (guest_id != info.first) { + break; + } + identifier = info.second; + } + } + + TraceL << "recv remote candidate: " << allArgs[CANDIDATE_KEY]; + + if (identifier.empty()) { + WarnL << "target room_id: " << room_id << " not match our room_id: " << getRoomId() + << ", and target guest_id: " << guest_id << " not match"; + return; + } + + auto transport = WebRtcTransportManager::Instance().getItem(identifier); + if (transport) { + SdpAttrCandidate candidate_attr; + candidate_attr.parse(allArgs[CANDIDATE_KEY]); + transport->connectivityCheck(candidate_attr, allArgs[UFRAG_KEY], allArgs[PWD_KEY]); + } + +} + +void WebRtcSignalingPeer::handleByeIndication(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY); + + if (allArgs[ROOM_ID_KEY] != getRoomId()) { + WarnL << "target room_id: " << allArgs[ROOM_ID_KEY] << "not match our room_id: " << getRoomId(); + return; + } + + auto it = _peer_guests.find(allArgs[GUEST_ID_KEY]); + if (it == _peer_guests.end()) { + WarnL << "not found guest_id: " << allArgs[GUEST_ID_KEY]; + return; + } + + auto identifier = it->second; + _peer_guests.erase(it); + auto obj = WebRtcTransportManager::Instance().getItem(identifier); + if (obj) { + obj->safeShutdown(SockException(Err_shutdown, "deleted by websocket signaling server")); + } +} + +void WebRtcSignalingPeer::sendByeIndication(const std::string& peer_room_id, const std::string &guest_id) { + DebugL; + Json::Value body; + body[CLASS_KEY] = CLASS_VALUE_INDICATION; + body[METHOD_KEY] = METHOD_VALUE_BYE; + body[GUEST_ID_KEY] = guest_id; //our guest id + body[ROOM_ID_KEY] = peer_room_id; + body[SENDER_KEY] = guest_id; + sendIndication(body); +} + +void WebRtcSignalingPeer::sendCandidateIndication(const std::string& transport_identifier, const std::string& candidate, const std::string& ice_ufrag, const std::string& ice_pwd) { + TraceL; + Json::Value body; + body[CLASS_KEY] = CLASS_VALUE_INDICATION; + body[METHOD_KEY] = METHOD_VALUE_CANDIDATE; + body[CANDIDATE_KEY] = candidate; + body[UFRAG_KEY] = ice_ufrag; + body[PWD_KEY] = ice_pwd; + + //作为被叫 + for (auto &pr : _peer_guests) { + if (pr.second == transport_identifier) { + body[ROOM_ID_KEY] = _room_id; + body[GUEST_ID_KEY] = pr.first; //peer_guest_id + body[SENDER_KEY] = _room_id; + return sendIndication(body); + } + } + + //作为主叫 + for (auto &pr : _tours) { + auto &info = pr.second; + if (info.second == transport_identifier) { + body[ROOM_ID_KEY] = pr.first; //peer room id + body[GUEST_ID_KEY] = info.first; //our_guest_id + body[SENDER_KEY] = info.first; + return sendIndication(body); + } + } +} + +void WebRtcSignalingPeer::sendAcceptResponse(const std::string& method, const std::string& transaction_id, const std::string& room_id, + const std::string& guest_id, const std::string& reason) { + // TODO +} + +void WebRtcSignalingPeer::sendRefusesResponse(Json::Value &body, const std::string& transaction_id, const std::string& reason) { + body[CLASS_KEY] = CLASS_VALUE_REJECT; + body[REASON_KEY] = reason; + sendResponse(body, transaction_id); +} + +void WebRtcSignalingPeer::sendRequest(Json::Value& body, ResponseTrigger trigger, float seconds) { + auto transaction_id = makeRandStr(32); + body[TRANSACTION_ID_KEY] = transaction_id; + + ResponseTuple tuple; + tuple.ttl_ms = seconds * 1000; + tuple.method = body[METHOD_KEY].asString(); + tuple.cb = std::move(trigger); + _response_list.emplace(std::move(transaction_id), std::move(tuple)); + sendPacket(body); +} + +void WebRtcSignalingPeer::sendIndication(Json::Value &body) { + auto transaction_id = makeRandStr(32); + body[TRANSACTION_ID_KEY] = transaction_id; + sendPacket(body); +} + +void WebRtcSignalingPeer::sendResponse(Json::Value &body, const std::string& transaction_id) { + body[TRANSACTION_ID_KEY] = transaction_id; + sendPacket(body); +} + +void WebRtcSignalingPeer::sendPacket(Json::Value& body) { + auto msg = body.toStyledString(); + DebugL << "send msg: " << msg; + SockSender::send(std::move(msg)); +} + +Json::Value WebRtcSignalingPeer::makeInfoJson() { + Json::Value item; + item["room_id"] = getRoomId(); + item["room_key"] = getRoomKey(); + + Json::Value peer_guests_obj(Json::arrayValue); + auto peer_guests = _peer_guests; + for(auto &guest : peer_guests) { + Json::Value obj; + obj["guest_id"] = guest.first; + obj["transport_identifier"] = guest.second; + peer_guests_obj.append(std::move(obj)); + } + item["guests"] = std::move(peer_guests_obj); + + Json::Value tours_obj(Json::arrayValue); + auto tours = _tours; + for(auto &tour : tours){ + Json::Value obj; + obj["room_id"] = tour.first; + obj["guest_id"] = tour.second.first; + obj["transport_identifier"] = tour.second.second; + tours_obj.append(std::move(obj)); + } + item["tours"] = std::move(tours_obj); + return item; +} + +void WebRtcSignalingPeer::createResponseExpiredTimer() { + std::weak_ptr weak_self = std::static_pointer_cast(shared_from_this()); + _expire_timer = std::make_shared(0.2, [weak_self]() { + if (auto strong_self = weak_self.lock()) { + strong_self->checkResponseExpired(); + return true; // 继续定时器 + } + return false; + }, getPoller()); +} + +void WebRtcSignalingPeer::checkResponseExpired() { + //FIXME: 移动到专门的超时timer中处理 +#if 0 + // 设置计时器以检测 offer 响应超时 + _offer_timeout_timer = std::make_shared( + timeout_sec, + [this, cb, peer_room_id]() { + _tours.erase(peer_room_id); + return false; // 停止计时器 + }, + getPoller() + ); +#endif + + for (auto it = _response_list.begin(); it != _response_list.end();) { + auto &tuple = it->second; + if (!tuple.expired()) { + ++it; + continue; + } + // over time + WarnL << "transaction_id: " << it->first << ", method: " << tuple.method << " recv response over time"; + tuple.cb(SockException(Err_timeout, "recv response timeout"), ""); + it = _response_list.erase(it); + } +} + +}// namespace mediakit diff --git a/webrtc/WebRtcSignalingPeer.h b/webrtc/WebRtcSignalingPeer.h new file mode 100644 index 00000000..7eaf179f --- /dev/null +++ b/webrtc/WebRtcSignalingPeer.h @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + + +#ifndef ZLMEDIAKIT_WEBRTC_SIGNALING_PEER_H +#define ZLMEDIAKIT_WEBRTC_SIGNALING_PEER_H + +#include +#include "Poller/Timer.h" +#include "Network/Session.h" +#include "Http/WebSocketClient.h" +#include "webrtc/WebRtcSignalingMsg.h" +#include "webrtc/WebRtcTransport.h" + +namespace mediakit { + +class WebRtcSignalingPeer : public WebSocketClient { +public: + struct ClassMethodHash { + bool operator()(std::pair key) const { + std::size_t h = 0; + h ^= std::hash()(key.first) << 0; + h ^= std::hash()(key.second) << 1; + return h; + } + }; + using Ptr = std::shared_ptr; + WebRtcSignalingPeer(const std::string &host, uint16_t port, bool ssl, const std::string &room_id, const toolkit::EventPoller::Ptr &poller = nullptr); + virtual ~WebRtcSignalingPeer(); + + void connect(); + void regist(const std::function &cb); + void unregist(const std::function &cb); + void checkIn(const std::string& peer_room_id, const MediaTuple &tuple, const std::string& identifier, + const std::string& offer, bool is_play, const std::function &cb, float timeout_sec); + void checkOut(const std::string& peer_room_id); + void candidate(const std::string& transport_identifier, const std::string& candidate, const std::string& ice_ufrag, const std::string& ice_pwd); + + void processOffer(SIGNALING_MSG_ARGS, WebRtcInterface &transport); + void answer(const std::string& guest_id, const MediaTuple &tuple, const std::string& identifier, const std::string& sdp, bool is_play, const std::string& transaction_id); + + const std::string& getRoomKey() const { + return _room_key; + } + + const std::string& getRoomId() const { + return _room_id; + } + + const RTC::IceServerInfo::Ptr& getIceServer() const { + return _ice_server; + } + + //// TcpClient override//// + void setOnConnect(std::function cb); + void onConnect(const toolkit::SockException &ex) override; + void setOnShutdown(std::function cb); + void onShutdown(const toolkit::SockException &ex); + void onRecv(const toolkit::Buffer::Ptr &) override; + void onError(const toolkit::SockException &err) override; + + Json::Value makeInfoJson(); + +protected: + void checkResponseExpired(); + void createResponseExpiredTimer(); + + using ResponseTrigger = std::function; + struct ResponseTuple { + toolkit::Ticker ticker; + uint32_t ttl_ms; + std::string method; + ResponseTrigger cb; + + bool expired() { + return ticker.elapsedTime() > ttl_ms; + } + }; + + bool responseFilter(SIGNALING_MSG_ARGS, ResponseTrigger& trigger); + + void sendRegisterRequest(ResponseTrigger trigger); + void handleRegisterAccept(SIGNALING_MSG_ARGS); + + void handleRegisterReject(SIGNALING_MSG_ARGS); + void sendUnregisterRequest(ResponseTrigger trigger); + void handleUnregisterAccept(SIGNALING_MSG_ARGS); + void handleUnregisterReject(SIGNALING_MSG_ARGS); + + void sendCallRequest(const std::string& peer_room_id, const std::string& guest_id, const MediaTuple &tuple, const std::string& sdp, bool is_play, ResponseTrigger trigger); + void sendCallAccept(const std::string& peer_guest_id, const MediaTuple &tuple, const std::string& sdp, bool is_play, const std::string& transaction_id); + void handleCallRequest(SIGNALING_MSG_ARGS); + void handleCallAccept(SIGNALING_MSG_ARGS); + void handleCallReject(SIGNALING_MSG_ARGS); + + void sendCandidateIndication(const std::string& transport_identifier, const std::string& candidate, const std::string& ice_ufrag, const std::string& ice_pwd); + void handleCandidateIndication(SIGNALING_MSG_ARGS); + + void sendByeIndication(const std::string& peer_room_id, const std::string &guest_id); + void handleByeIndication(SIGNALING_MSG_ARGS); + + void sendAcceptResponse(const std::string& method, const std::string& transaction_id, const std::string& room_id, const std::string& guest_id, const std::string& reason); + void sendRefusesResponse(Json::Value &body, const std::string& transaction_id, const std::string& reason); + + void sendIndication(Json::Value &body); + void sendRequest(Json::Value& body, ResponseTrigger trigger, float seconds = 10); + void sendResponse(Json::Value &body, const std::string& transaction_id); + void sendPacket(Json::Value& body); + +private: + toolkit::Timer::Ptr _expire_timer; + std::string _ws_url; + std::string _room_key; + std::string _room_id; + std::unordered_map _peer_guests; //作为被叫 + std::unordered_map> _tours; //作为主叫 + RTC::IceServerInfo::Ptr _ice_server; + std::unordered_map _response_list; + + std::function _on_connect; + std::function _on_shutdown; + toolkit::Timer::Ptr _offer_timeout_timer = nullptr; +}; + +void addWebrtcRoomKeeper(const std::string &host, uint16_t port, const std::string& room_id, bool ssl, + const std::function &cb); +void delWebrtcRoomKeeper(const std::string &key, const std::function &cb); +void listWebrtcRoomKeepers(const std::function &cb); +Json::Value ToJson(const WebRtcSignalingPeer::Ptr& p); +WebRtcSignalingPeer::Ptr getWebrtcRoomKeeper(const std::string &host, uint16_t port); + +} // namespace mediakit + +#endif // ZLMEDIAKIT_WEBRTC_SIGNALING_PEER_H diff --git a/webrtc/WebRtcSignalingSession.cpp b/webrtc/WebRtcSignalingSession.cpp new file mode 100644 index 00000000..4f8e6d79 --- /dev/null +++ b/webrtc/WebRtcSignalingSession.cpp @@ -0,0 +1,488 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "Util/util.h" +#include "Common/config.h" +#include "WebRtcTransport.h" +#include "WebRtcSignalingMsg.h" +#include "WebRtcSignalingSession.h" + +using namespace std; +using namespace toolkit; +using namespace mediakit::Rtc; + +namespace mediakit { + +// 注册上来的peer列表 +static std::atomic s_room_idx_generate { 1 }; +static ServiceController s_rooms; + +void listWebrtcRooms(const std::function &cb) { + s_rooms.for_each(cb); +} + +Json::Value ToJson(const WebRtcSignalingSession::Ptr &p) { + return p->makeInfoJson(); +} + +WebRtcSignalingSession::Ptr getWebrtcRoomKeeper(const string &room_id) { + return s_rooms.find(room_id); +} + +//////////// WebRtcSignalingSession ////////////////////////// + +WebRtcSignalingSession::WebRtcSignalingSession(const Socket::Ptr &sock) : Session(sock) { + DebugL; +} + +WebRtcSignalingSession::~WebRtcSignalingSession() { + DebugL << "room_id: " << _room_id; +} + +void WebRtcSignalingSession::onRecv(const Buffer::Ptr &buffer) { + DebugL << "recv msg:\r\n" << buffer->data(); + + Json::Value args; + Json::Reader reader; + reader.parse(buffer->data(), args); + Parser parser; + HttpAllArgs allArgs(parser, args); + + using MsgHandler = void (WebRtcSignalingSession::*)(SIGNALING_MSG_ARGS); + static std::unordered_map, MsgHandler, ClassMethodHash> s_msg_handlers; + + static onceToken token([]() { + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REQUEST, METHOD_VALUE_REGISTER), &WebRtcSignalingSession::handleRegisterRequest); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REQUEST, METHOD_VALUE_UNREGISTER), &WebRtcSignalingSession::handleUnregisterRequest); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REQUEST, METHOD_VALUE_CALL), &WebRtcSignalingSession::handleCallRequest); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_ACCEPT, METHOD_VALUE_CALL), &WebRtcSignalingSession::handleCallAccept); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_REJECT, METHOD_VALUE_CALL), &WebRtcSignalingSession::handleCallReject); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_INDICATION, METHOD_VALUE_BYE), &WebRtcSignalingSession::handleByeIndication); + s_msg_handlers.emplace(std::make_pair(CLASS_VALUE_INDICATION, METHOD_VALUE_CANDIDATE), &WebRtcSignalingSession::handleCandidateIndication); + }); + + try { + CHECK_ARGS(CLASS_KEY, METHOD_KEY, TRANSACTION_ID_KEY); + auto it = s_msg_handlers.find(std::make_pair(allArgs[CLASS_KEY], allArgs[METHOD_KEY])); + if (it == s_msg_handlers.end()) { + WarnL << " not support class: " << allArgs[CLASS_KEY] << ", method: " << allArgs[METHOD_KEY] << ", ignore"; + return; + } + + (this->*(it->second))(allArgs); + } catch (std::exception &ex) { + WarnL << "process msg fail: " << ex.what(); + } +} + +void WebRtcSignalingSession::onError(const SockException &err) { + WarnL << "room_id: " << _room_id; + notifyByeIndication(); + s_rooms.erase(_room_id); +} + +void WebRtcSignalingSession::onManager() { + // Websocket会话会自行定时发送PING/PONG 消息,并进行超时自己管理,该对象暂时不需要心跳超时处理 +} + +void WebRtcSignalingSession::handleRegisterRequest(SIGNALING_MSG_ARGS) { + DebugL; + + std::string room_id; + Json::Value body; + body[METHOD_KEY] = METHOD_VALUE_REGISTER; + + // 如果客户端没有提供 room_id,服务端自动分配一个 + if (allArgs[ROOM_ID_KEY].empty()) { + auto idx = s_room_idx_generate.fetch_add(1); + room_id = std::to_string(idx) + "_" + makeRandStr(16); + DebugL << "auto generated room_id: " << room_id; + } else { + room_id = allArgs[ROOM_ID_KEY]; + if (s_rooms.find(room_id)) { + // 已经注册了 + body[ROOM_ID_KEY] = room_id; + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], "room id conflict"); + } + } + + body[ROOM_ID_KEY] = room_id; + + _room_id = room_id; + s_rooms.emplace(_room_id, shared_from_this()); + sendRegisterAccept(body, allArgs[TRANSACTION_ID_KEY]); +} + +void WebRtcSignalingSession::handleUnregisterRequest(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(ROOM_ID_KEY); + + Json::Value body; + body[METHOD_KEY] = METHOD_VALUE_UNREGISTER; + body[ROOM_ID_KEY] = allArgs[ROOM_ID_KEY]; + + if (_room_id.empty()) { + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], "unregistered"); + } + + if (allArgs[ROOM_ID_KEY] != getRoomId()) { + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], StrPrinter << "room_id: \"" << allArgs[ROOM_ID_KEY] << "\" not match room_id:" << getRoomId()); + } + + sendAcceptResponse(body, allArgs[TRANSACTION_ID_KEY]); + + // 同时主动向所有连接的对端会话发送bye + notifyByeIndication(); + + if (s_rooms.find(_room_id)) { + s_rooms.erase(_room_id); + } +} + +void WebRtcSignalingSession::handleCallRequest(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(TRANSACTION_ID_KEY, GUEST_ID_KEY, ROOM_ID_KEY, CALL_VHOST_KEY, CALL_APP_KEY, CALL_STREAM_KEY, TYPE_KEY, SDP_KEY); + + Json::Value body; + body[METHOD_KEY] = METHOD_VALUE_CALL; + body[ROOM_ID_KEY] = allArgs[ROOM_ID_KEY]; + body[GUEST_ID_KEY] = allArgs[GUEST_ID_KEY]; + body[CALL_VHOST_KEY] = allArgs[CALL_VHOST_KEY]; + body[CALL_APP_KEY] = allArgs[CALL_APP_KEY]; + body[CALL_STREAM_KEY] = allArgs[CALL_STREAM_KEY]; + body[TYPE_KEY] = allArgs[TYPE_KEY]; + if (_room_id.empty()) { + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], "should register first"); + } + auto peer_id = allArgs[ROOM_ID_KEY]; + auto session = getWebrtcRoomKeeper(peer_id); + if (!session) { + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], StrPrinter << "room_id: \"" << peer_id << "\" unregistered"); + } + + _tours.emplace(allArgs[GUEST_ID_KEY], peer_id); + // forwardOffer + weak_ptr sender_ptr = static_pointer_cast(shared_from_this()); + session->forwardCallRequest(sender_ptr, allArgs); +} + +void WebRtcSignalingSession::handleCallAccept(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY, CALL_VHOST_KEY, CALL_APP_KEY, CALL_STREAM_KEY); + + Json::Value body; + body[ROOM_ID_KEY] = allArgs[ROOM_ID_KEY]; + + if (_room_id.empty()) { + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], "should register first"); + } + + auto it = _guests.find(allArgs[GUEST_ID_KEY]); + if (it == _guests.end()) { + WarnL << "guest_id: \"" << allArgs[GUEST_ID_KEY] << "\" not register"; + return; + } + auto session = it->second.lock(); + if (!session) { + WarnL << "guest_id: \"" << allArgs[GUEST_ID_KEY] << "\" leave alreadly"; + return; + } + + session->forwardCallAccept(allArgs); +} + +void WebRtcSignalingSession::handleByeIndication(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(GUEST_ID_KEY, ROOM_ID_KEY); + auto guest_id = allArgs[GUEST_ID_KEY]; + + Json::Value body; + body[METHOD_KEY] = METHOD_VALUE_BYE; + body[ROOM_ID_KEY] = allArgs[ROOM_ID_KEY]; + body[GUEST_ID_KEY] = guest_id; + if (_room_id.empty()) { + return sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], "should register first"); + } + if (allArgs[ROOM_ID_KEY] == getRoomId()) { + // 作为被叫方,接收bye + auto it = _guests.find(guest_id); + if (it == _guests.end()) { + WarnL << "guest_id: \"" << guest_id << "\" not register"; + return; + } + auto session = it->second.lock(); + if (!session) { + WarnL << "guest_id: \"" << guest_id << "\" leave alreadly"; + return; + } + _guests.erase(guest_id); + session->forwardBye(allArgs); + } else { + // 作为主叫方,接受bye + auto session = getWebrtcRoomKeeper(allArgs[ROOM_ID_KEY]); + if (!session) { + WarnL << "room_id: \"" << allArgs[ROOM_ID_KEY] << "\" not register"; + return; + } + _tours.erase(guest_id); + session->forwardBye(allArgs); + } +} + +void WebRtcSignalingSession::handleCandidateIndication(SIGNALING_MSG_ARGS) { + DebugL; + CHECK_ARGS(TRANSACTION_ID_KEY, GUEST_ID_KEY, ROOM_ID_KEY, CANDIDATE_KEY, UFRAG_KEY, PWD_KEY); + + Json::Value body; + body[METHOD_KEY] = METHOD_VALUE_CANDIDATE; + body[ROOM_ID_KEY] = allArgs[ROOM_ID_KEY]; + + if (_room_id.empty()) { + sendRejectResponse(body, allArgs[TRANSACTION_ID_KEY], "should register first"); + } else { + handleOtherMsg(allArgs); + } +} + +void WebRtcSignalingSession::handleOtherMsg(SIGNALING_MSG_ARGS) { + DebugL; + if (allArgs[ROOM_ID_KEY] == getRoomId()) { + // 作为被叫方,接收bye + auto guest_id = allArgs[GUEST_ID_KEY]; + auto it = _guests.find(guest_id); + if (it == _guests.end()) { + WarnL << "guest_id: \"" << guest_id << "\" not register"; + return; + } + auto session = it->second.lock(); + if (!session) { + WarnL << "guest_id: \"" << guest_id << "\" leave alreadly"; + return; + } + + session->forwardPacket(allArgs); + } else { + // 作为主叫方,接受bye + auto session = getWebrtcRoomKeeper(allArgs[ROOM_ID_KEY]); + if (!session) { + WarnL << "room_id: \"" << allArgs[ROOM_ID_KEY] << "\" not register"; + return; + } + session->forwardPacket(allArgs); + } +} + +void WebRtcSignalingSession::notifyByeIndication() { + DebugL; + + Json::Value allArgs; + allArgs[CLASS_KEY] = CLASS_VALUE_INDICATION; + allArgs[METHOD_KEY] = METHOD_VALUE_BYE; + allArgs[REASON_KEY] = "peer unregister"; + // 作为被叫方 + for (auto it : _guests) { + auto session = it.second.lock(); + if (session) { + allArgs[TRANSACTION_ID_KEY] = makeRandStr(32); + allArgs[GUEST_ID_KEY] = it.first; + allArgs[ROOM_ID_KEY] = getRoomId(); + session->forwardBye(allArgs); + } + } + + // 作为主叫方 + for (auto it : _tours) { + auto guest_id = it.first; + auto peer_room_id = it.second; + auto session = getWebrtcRoomKeeper(peer_room_id); + if (session) { + allArgs[TRANSACTION_ID_KEY] = makeRandStr(32); + allArgs[GUEST_ID_KEY] = guest_id; + allArgs[ROOM_ID_KEY] = peer_room_id; + session->forwardBye(allArgs); + } + } +} + +void WebRtcSignalingSession::forwardCallRequest(WebRtcSignalingSession::WeakPtr sender, SIGNALING_MSG_ARGS) { + DebugL; + WeakPtr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, sender, allArgs]() { + if (auto strong_self = weak_self.lock()) { + strong_self->_guests.emplace(allArgs[GUEST_ID_KEY], sender); + strong_self->sendPacket(allArgs.getArgs()); + } + }); +} + +void WebRtcSignalingSession::forwardCallAccept(SIGNALING_MSG_ARGS) { + DebugL; + WeakPtr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, allArgs]() { + if (auto strong_self = weak_self.lock()) { + strong_self->sendPacket(allArgs.getArgs()); + } + }); +} + +void WebRtcSignalingSession::forwardBye(SIGNALING_MSG_ARGS) { + DebugL; + WeakPtr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, allArgs]() { + if (auto strong_self = weak_self.lock()) { + if (allArgs[ROOM_ID_KEY] == strong_self->getRoomId()) { + // 作为被叫 + strong_self->_guests.erase(allArgs[GUEST_ID_KEY]); + } else { + // 作为主叫 + strong_self->_tours.erase(allArgs[GUEST_ID_KEY]); + } + strong_self->sendPacket(allArgs.getArgs()); + } + }); +} + +void WebRtcSignalingSession::forwardBye(Json::Value allArgs) { + DebugL; + WeakPtr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, allArgs]() { + if (auto strong_self = weak_self.lock()) { + if (allArgs[ROOM_ID_KEY] == strong_self->getRoomId()) { + // 作为被叫 + strong_self->_guests.erase(allArgs[GUEST_ID_KEY].asString()); + } else { + // 作为主叫 + strong_self->_tours.erase(allArgs[GUEST_ID_KEY].asString()); + } + strong_self->sendPacket(allArgs); + } + }); +} + +void WebRtcSignalingSession::forwardPacket(SIGNALING_MSG_ARGS) { + WeakPtr weak_self = std::static_pointer_cast(shared_from_this()); + getPoller()->async([weak_self, allArgs]() { + if (auto strong_self = weak_self.lock()) { + strong_self->sendPacket(allArgs.getArgs()); + } + }); +} + +void WebRtcSignalingSession::sendRegisterAccept(Json::Value& body, const std::string& transaction_id) { + DebugL; + body[CLASS_KEY] = CLASS_VALUE_ACCEPT; + + Json::Value ice_server; + GET_CONFIG(uint16_t, icePort, Rtc::kIcePort); + GET_CONFIG(bool, enable_turn, Rtc::kEnableTurn); + GET_CONFIG(string, iceUfrag, Rtc::kIceUfrag); + GET_CONFIG(string, icePwd, Rtc::kIcePwd); + GET_CONFIG_FUNC(std::vector, extern_ips, Rtc::kExternIP, [](string str) { + std::vector ret; + if (str.length()) { + ret = split(str, ","); + } + translateIPFromEnv(ret); + return ret; + }); + + // 如果配置了extern_ips, 则选择第一个作为turn服务器的ip + // 如果没配置获取网卡接口 + std::string extern_ip; + if (!extern_ips.empty()) { + extern_ip = extern_ips.front(); + } else { + extern_ip = SockUtil::get_local_ip(); + } + + // TODO: support multi extern ip + // TODO: support third stun/turn server + + std::string url; + // SUPPORT: + // stun:host:port?transport=udp + // turn:host:port?transport=udp + + // NOT SUPPORT NOW TODO: + // turns:host:port?transport=udp + // turn:host:port?transport=tcp + // turns:host:port?transport=tcp + // stuns:host:port?transport=udp + // stuns:host:port?transport=udp + // stun:host:port?transport=tcp + if (enable_turn) { + url = "turn:" + extern_ip + ":" + std::to_string(icePort) + "?transport=udp"; + } else { + url = "stun:" + extern_ip + ":" + std::to_string(icePort) + "?transport=udp"; + } + + ice_server[URL_KEY] = url; + ice_server[UFRAG_KEY] = iceUfrag; + ice_server[PWD_KEY] = icePwd; + + Json::Value ice_servers; + ice_servers.append(ice_server); + + body[ICE_SERVERS_KEY] = ice_servers; + + sendAcceptResponse(body, transaction_id); +} + +void WebRtcSignalingSession::sendAcceptResponse(Json::Value &body, const std::string &transaction_id) { + TraceL; + body[CLASS_KEY] = CLASS_VALUE_ACCEPT; + return sendResponse(body, transaction_id); +} + +void WebRtcSignalingSession::sendRejectResponse(Json::Value &body, const std::string &transaction_id, const std::string &reason) { + DebugL; + body[CLASS_KEY] = CLASS_VALUE_REJECT; + body[REASON_KEY] = reason; + return sendResponse(body, transaction_id); +} + +void WebRtcSignalingSession::sendResponse(Json::Value &body, const std::string &transaction_id) { + DebugL; + body[TRANSACTION_ID_KEY] = transaction_id; + return sendPacket(body); +} + +void WebRtcSignalingSession::sendPacket(const Json::Value &body) { + auto msg = body.toStyledString(); + TraceL << "send msg: " << msg; + SockSender::send(msg); +} + +Json::Value WebRtcSignalingSession::makeInfoJson() { + Json::Value item; + item["room_id"] = getRoomId(); + + Json::Value tours_obj(Json::arrayValue); + auto tours = _tours; + for (auto &tour : tours) { + Json::Value obj; + obj["guest_id"] = tour.first; + obj["room_id"] = tour.second; + tours_obj.append(std::move(obj)); + } + item["tours"] = std::move(tours_obj); + + Json::Value guests_obj(Json::arrayValue); + auto guests = _guests; + for (auto &guest : guests) { + Json::Value obj; + obj["guest_id"] = guest.first; + guests_obj.append(std::move(obj)); + } + item["guests"] = std::move(guests_obj); + return item; +} + +} // namespace mediakit diff --git a/webrtc/WebRtcSignalingSession.h b/webrtc/WebRtcSignalingSession.h new file mode 100644 index 00000000..557895a0 --- /dev/null +++ b/webrtc/WebRtcSignalingSession.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTC_SIGNALING_SESSION_H +#define ZLMEDIAKIT_WEBRTC_SIGNALING_SESSION_H + +#include "Network/Session.h" +#include "Http/WebSocketSession.h" +#include "webrtc/WebRtcSignalingMsg.h" + +namespace mediakit { + +// webrtc 信令, 基于websocket实现 +class WebRtcSignalingSession : public toolkit::Session { +public: + struct ClassMethodHash { + bool operator()(std::pair key) const { + std::size_t h = 0; + h ^= std::hash()(key.first) << 0; + h ^= std::hash()(key.second) << 1; + return h; + } + }; + + using Ptr = std::shared_ptr; + using WeakPtr = std::weak_ptr; + + WebRtcSignalingSession(const toolkit::Socket::Ptr &sock); + virtual ~WebRtcSignalingSession(); + + Json::Value makeInfoJson(); + + std::string getRoomId() { return _room_id; }; + + //// Session override//// + void onRecv(const toolkit::Buffer::Ptr &) override; + void onError(const toolkit::SockException &err) override; + void onManager() override; + +protected: + void handleRegisterRequest(SIGNALING_MSG_ARGS); + void handleUnregisterRequest(SIGNALING_MSG_ARGS); + void handleCallRequest(SIGNALING_MSG_ARGS); + void handleCallAccept(SIGNALING_MSG_ARGS); + #define handleCallReject handleCallAccept + void handleByeIndication(SIGNALING_MSG_ARGS); + void handleCandidateIndication(SIGNALING_MSG_ARGS); + void handleOtherMsg(SIGNALING_MSG_ARGS); + + void notifyByeIndication(); + void forwardCallRequest(WebRtcSignalingSession::WeakPtr sender, SIGNALING_MSG_ARGS); + void forwardCallAccept(SIGNALING_MSG_ARGS); + void forwardBye(SIGNALING_MSG_ARGS); + void forwardBye(Json::Value allArgs); + void forwardPacket(SIGNALING_MSG_ARGS); + + void sendRegisterAccept(Json::Value& body, const std::string& transaction_id); + void sendAcceptResponse(Json::Value &body, const std::string& transaction_id); + void sendRejectResponse(Json::Value &body, const std::string& transaction_id, const std::string& reason); + + void sendResponse(Json::Value &body, const std::string& transaction_id); + void sendPacket(const Json::Value &body); + +private: + std::string _room_id; // + std::unordered_map _tours; //作为主叫 + std::unordered_map _guests; //作为被叫 +}; + +using WebRtcWebcosktSignalingSession = WebSocketSession; +using WebRtcWebcosktSignalSslSession = WebSocketSession; + +void listWebrtcRooms(const std::function &cb); +Json::Value ToJson(const WebRtcSignalingSession::Ptr& p); +WebRtcSignalingSession::Ptr getWebrtcRoomKeeper(const std::string &room_id); +}// namespace mediakit + +#endif //ZLMEDIAKIT_WEBRTC_SIGNALING_SESSION_H diff --git a/webrtc/WebRtcTalk.cpp b/webrtc/WebRtcTalk.cpp new file mode 100644 index 00000000..1c4dc4ce --- /dev/null +++ b/webrtc/WebRtcTalk.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#include "WebRtcTalk.h" + +#include "Util/base64.h" +#include "Common/config.h" +#include "Extension/Factory.h" +#include "Common/MultiMediaSourceMuxer.h" + +using namespace std; +using namespace toolkit; + +namespace mediakit { + +WebRtcTalk::Ptr WebRtcTalk::create( + const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info, WebRtcTransport::Role role, + WebRtcTransport::SignalingProtocols signaling_protocols) { + WebRtcTalk::Ptr ret(new WebRtcTalk(poller, src, info), [](WebRtcTalk *ptr) { + ptr->onDestory(); + delete ptr; + }); + ret->setRole(role); + ret->setSignalingProtocols(signaling_protocols); + ret->onCreate(); + return ret; +} + +WebRtcTalk::WebRtcTalk(const EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info) + : WebRtcTransportImp(poller) { + _media_info = info; + _play_src = src; + CHECK(src); + _demuxer = std::make_shared(); +} + +void WebRtcTalk::onStartWebRTC() { + auto playSrc = _play_src.lock(); + if (!playSrc) { + onShutdown(SockException(Err_shutdown, "rtsp media source was shutdown")); + return; + } + WebRtcTransportImp::onStartWebRTC(); + // 不支持simulcast + CHECK(!_answer_sdp->supportSimulcast()); + auto sdp = _answer_sdp->toRtspSdp(); + _demuxer->loadSdp(sdp); + auto audio_track = _demuxer->getTrack(TrackAudio, false); + // 必须包含音频track + CHECK(audio_track); + audio_track->addDelegate([this](const Frame::Ptr &frame) { + // 发送对讲语音rtp流 + _sender->inputFrame(frame); + return true; + }); + + MediaSourceEvent::SendRtpArgs args; + args.con_type = MediaSourceEvent::SendRtpArgs::kVoiceTalk; + args.recv_stream_vhost = playSrc->getMediaTuple().vhost; + args.recv_stream_app = playSrc->getMediaTuple().app; + args.recv_stream_id = playSrc->getMediaTuple().stream; + auto url_args = Parser::parseArgs(_media_info.params); + args.data_type = static_cast(atoi(url_args["data_type"].data())); + args.only_audio = true; + args.pt = static_cast(atoi(url_args["pt"].data())); + args.ssrc = url_args["ssrc"]; + + std::weak_ptr weak_self = static_pointer_cast(shared_from_this()); + _sender = std::make_shared(getPoller()); + _sender->startSend(*(playSrc->getMuxer()), args, [weak_self](uint16_t local_port, const SockException &ex) { + if (!ex) { + return; + } + if (auto strong_self = weak_self.lock()) { + strong_self->onShutdown(ex); + } + }); + + _sender->addTrack(audio_track); + _sender->addTrackCompleted(); + + if (canSendRtp()) { + playSrc->pause(false); + _reader = playSrc->getRing()->attach(getPoller(), true); + weak_ptr weak_self = static_pointer_cast(shared_from_this()); + weak_ptr weak_session = static_pointer_cast(getSession()); + _reader->setGetInfoCB([weak_session]() { + Any ret; + ret.set(static_pointer_cast(weak_session.lock())); + return ret; + }); + _reader->setReadCB([weak_self](const RtspMediaSource::RingDataType &pkt) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + + size_t i = 0; + pkt->for_each([&](const RtpPacket::Ptr &rtp) { strong_self->onSendRtp(rtp, ++i == pkt->size()); }); + }); + _reader->setDetachCB([weak_self]() { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + strong_self->onShutdown(SockException(Err_shutdown, "rtsp ring buffer detached")); + }); + + _reader->setMessageCB([weak_self](const toolkit::Any &data) { + auto strong_self = weak_self.lock(); + if (!strong_self) { + return; + } + if (data.is()) { + auto &buffer = data.get(); + // PPID 51: 文本string [AUTO-TRANSLATED:69a8cf81] + // PPID 51: Text string + // PPID 53: 二进制 [AUTO-TRANSLATED:faf00c3e] + // PPID 53: Binary + strong_self->sendDatachannel(0, 51, buffer.data(), buffer.size()); + } else { + WarnL << "Send unknown message type to webrtc player: " << data.type_name(); + } + }); + } +} +void WebRtcTalk::onDestory() { + auto duration = getDuration(); + auto bytes_usage = getBytesUsage(); + // 流量统计事件广播 [AUTO-TRANSLATED:6b0b1234] + // Traffic statistics event broadcast + GET_CONFIG(uint32_t, iFlowThreshold, General::kFlowThreshold); + auto session = getSession(); + if (_reader && session) { + WarnL << "RTC对讲(" << _media_info.shortUrl() << ")结束播放,耗时(s):" << duration; + if (bytes_usage >= iFlowThreshold * 1024) { + NOTICE_EMIT(BroadcastFlowReportArgs, Broadcast::kBroadcastFlowReport, _media_info, bytes_usage, duration, true, *session); + } + } + WebRtcTransportImp::onDestory(); +} + +void WebRtcTalk::onRtcConfigure(RtcConfigure &configure) const { + WebRtcTransportImp::onRtcConfigure(configure); + auto playSrc = _play_src.lock(); + if (playSrc) { + configure.setPlayRtspInfo(playSrc->getSdp()); + } + + // 不接收视频 + configure.video.direction = static_cast(static_cast(configure.video.direction) & ~static_cast(RtpDirection::recvonly)); + // 开启音频接收 + configure.audio.direction = static_cast(static_cast(configure.audio.direction) | static_cast(RtpDirection::recvonly)); +} + +void WebRtcTalk::onRecvRtp(MediaTrack &track, const std::string &rid, RtpPacket::Ptr rtp) { + // rtp解析为音频,视频丢弃 + if (rtp->type == TrackAudio) { + _demuxer->inputRtp(rtp); + } +} + + +} // namespace mediakit \ No newline at end of file diff --git a/webrtc/WebRtcTalk.h b/webrtc/WebRtcTalk.h new file mode 100644 index 00000000..c6a3bc05 --- /dev/null +++ b/webrtc/WebRtcTalk.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved. + * + * This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit). + * + * Use of this source code is governed by MIT-like license that can be found in the + * LICENSE file in the root of the source tree. All contributing project authors + * may be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef ZLMEDIAKIT_WEBRTC_TALK_H +#define ZLMEDIAKIT_WEBRTC_TALK_H + +#include "WebRtcTransport.h" +#include "Rtsp/RtspMediaSource.h" +#include "Rtsp/RtspDemuxer.h" +#include "Rtp/RtpSender.h" + +namespace mediakit { + +class WebRtcTalk : public WebRtcTransportImp { +public: + using Ptr = std::shared_ptr; + static Ptr create(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info, + WebRtcTransport::Role role, WebRtcTransport::SignalingProtocols signaling_protocols); + +protected: + ///////WebRtcTransportImp override/////// + void onStartWebRTC() override; + void onDestory() override; + void onRtcConfigure(RtcConfigure &configure) const override; + void onRecvRtp(MediaTrack &track, const std::string &rid, RtpPacket::Ptr rtp) override; + +private: + WebRtcTalk(const toolkit::EventPoller::Ptr &poller, const RtspMediaSource::Ptr &src, const MediaInfo &info); + +private: + // 媒体相关元数据 [AUTO-TRANSLATED:f4cf8045] + // Media related metadata + MediaInfo _media_info; + // 播放的rtsp源 [AUTO-TRANSLATED:9963eed1] + // Playing rtsp source + std::weak_ptr _play_src; + + // 播放rtsp源的reader对象 [AUTO-TRANSLATED:7b305055] + // Reader object for playing rtsp source + RtspMediaSource::RingType::RingReader::Ptr _reader; + + // 解析对讲语音rtp流为帧数据 + RtspDemuxer::Ptr _demuxer; + // 打包语音帧数据为特定rtp并回复过去 + RtpSender::Ptr _sender; +}; + +}// namespace mediakit +#endif // ZLMEDIAKIT_WEBRTC_TALK_H diff --git a/webrtc/WebRtcTransport.cpp b/webrtc/WebRtcTransport.cpp index 1dbe2bcd..cf492de5 100644 --- a/webrtc/WebRtcTransport.cpp +++ b/webrtc/WebRtcTransport.cpp @@ -9,6 +9,9 @@ */ #include +#include +#include +#include #include #include "Util/base64.h" #include "Network/sockutil.h" @@ -25,6 +28,7 @@ #include "WebRtcEchoTest.h" #include "WebRtcPlayer.h" #include "WebRtcPusher.h" +#include "WebRtcTalk.h" #include "Rtsp/RtspMediaSourceImp.h" #define RTP_SSRC_OFFSET 1 @@ -35,6 +39,7 @@ using namespace std; +using namespace toolkit; namespace mediakit { // RTC配置项目 [AUTO-TRANSLATED:19940011] @@ -47,6 +52,7 @@ const string kTimeOutSec = RTC_FIELD "timeoutSec"; // 服务器外网ip [AUTO-TRANSLATED:23283ba6] // Server external network ip const string kExternIP = RTC_FIELD "externIP"; +const string kInterfaces = RTC_FIELD "interfaces"; // 设置remb比特率,非0时关闭twcc并开启remb。该设置在rtc推流时有效,可以控制推流画质 [AUTO-TRANSLATED:412801db] // Set remb bitrate, when it is not 0, turn off twcc and turn on remb. This setting is valid when rtc pushes the stream, and can control the pushing stream quality const string kRembBitRate = RTC_FIELD "rembBitRate"; @@ -57,6 +63,18 @@ const string kTranscodeG711 = RTC_FIELD "transcodeG711"; // webrtc single-port udp server const string kPort = RTC_FIELD "port"; const string kTcpPort = RTC_FIELD "tcpPort"; +// webrtc SignalingServerPort udp server +const string kSignalingPort = RTC_FIELD "signalingPort"; +const string kSignalingSslPort = RTC_FIELD "signalingSslPort"; +// webrtc iceServer udp server +const string kIcePort = RTC_FIELD "icePort"; +const string kIceTcpPort = RTC_FIELD "iceTcpPort"; +// webrtc enable turn or only enable stun +const string kEnableTurn = RTC_FIELD "enableTurn"; +// webrtc ice ufrag and pwd [AUTO-TRANSLATED:2f0d1b3c] +const string kIceUfrag = RTC_FIELD "iceUfrag"; +const string kIcePwd = RTC_FIELD "icePwd"; +const string kIceTransportPolicy = RTC_FIELD "iceTransportPolicy"; // 比特率设置 [AUTO-TRANSLATED:2c75f5bc] // Bitrate setting @@ -67,10 +85,12 @@ const string kMinBitrate = RTC_FIELD "min_bitrate"; // 数据通道设置 [AUTO-TRANSLATED:2dc48bc3] // Data channel setting const string kDataChannelEcho = RTC_FIELD "datachannel_echo"; +const string kPreferredTcp = RTC_FIELD "preferred_tcp"; static onceToken token([]() { mINI::Instance()[kTimeOutSec] = 15; mINI::Instance()[kExternIP] = ""; + mINI::Instance()[kInterfaces] = ""; mINI::Instance()[kRembBitRate] = 0; mINI::Instance()[kPort] = 8000; mINI::Instance()[kTcpPort] = 8000; @@ -80,28 +100,23 @@ static onceToken token([]() { mINI::Instance()[kMinBitrate] = 0; mINI::Instance()[kDataChannelEcho] = true; - mINI::Instance()[kTranscodeG711] = 0; + mINI::Instance()[kTranscodeG711] = 0; + + mINI::Instance()[kSignalingPort] = 3000; + mINI::Instance()[kSignalingSslPort] = 3001; + mINI::Instance()[kIcePort] = 3478; + mINI::Instance()[kIceTcpPort] = 3478; + mINI::Instance()[kEnableTurn] = 1; + mINI::Instance()[kIceTransportPolicy] = 0; // 默认值:不限制(kAll) + mINI::Instance()[kIceUfrag] = "ZLMediaKit"; + mINI::Instance()[kIcePwd] = "ZLMediaKit"; + mINI::Instance()[kPreferredTcp] = 0; }); -} // namespace RTC +} // namespace Rtc static atomic s_key { 0 }; -static void translateIPFromEnv(std::vector &v) { - for (auto iter = v.begin(); iter != v.end();) { - if (start_with(*iter, "$")) { - auto ip = toolkit::getEnv(*iter); - if (ip.empty()) { - iter = v.erase(iter); - } else { - *iter++ = ip; - } - } else { - ++iter; - } - } -} - static std::string getServerPrefix() { // stun_user_name格式: base64(ip+udp_port+tcp_port) + _ + number [AUTO-TRANSLATED:cc3c5902] // stun_user_name format: base64(ip+udp_port+tcp_port) + _ + number @@ -131,17 +146,135 @@ static std::string getServerPrefix() { return ret; } -const char* sockTypeStr(Session* session) { - if (session) { - switch (session->getSock()->sockType()) { - case SockNum::Sock_TCP: return "tcp"; - case SockNum::Sock_UDP: return "udp"; - default: break; +static std::string mappingCandidateTypeEnum2Str(CandidateInfo::AddressType type) { + switch (type) { + case CandidateInfo::AddressType::HOST: return "host"; + case CandidateInfo::AddressType::SRFLX: return "srflx"; + case CandidateInfo::AddressType::PRFLX: return "prflx"; + case CandidateInfo::AddressType::RELAY: return "relay"; + default: break; + } + return "invalid"; +} + +static CandidateInfo::AddressType mappingCandidateTypeStr2Enum(const std::string &type) { + if (strcasecmp(type.c_str(), "host") == 0) { + return CandidateInfo::AddressType::HOST; + } + if (strcasecmp(type.c_str(), "srflx") == 0) { + return CandidateInfo::AddressType::SRFLX; + } + if (strcasecmp(type.c_str(), "prflx") == 0) { + return CandidateInfo::AddressType::PRFLX; + } + if (strcasecmp(type.c_str(), "relay") == 0) { + return CandidateInfo::AddressType::RELAY; + } + return CandidateInfo::AddressType::INVALID; +} + +// 根据RFC 5245标准计算foundation +// 1. IP地址类型(IPv4/IPv6) +// 2. 传输协议(UDP/TCP) +// 3. 候选类型(host/srflx/prflx/relay) +// 4. STUN/TURN服务器地址(对于srflx和relay类型) +static std::string calculateFoundation(const std::string& ip, const std::string& proto, const std::string& type, const std::string& stun_server = "") { + // 将协议和类型转换为小写以确保一致性 + std::string proto_lower = proto; + std::string type_lower = type; + std::transform(proto_lower.begin(), proto_lower.end(), proto_lower.begin(), ::tolower); + std::transform(type_lower.begin(), type_lower.end(), type_lower.begin(), ::tolower); + + std::string foundation_base = type_lower + "-" + ip + "-" + proto_lower; + + // 对于server reflexive和relay候选,需要包含STUN/TURN服务器地址 + if ((type_lower == "srflx" || type_lower == "relay") && !stun_server.empty()) { + foundation_base += "-" + stun_server; + } + + std::hash hasher; + size_t hash_value = hasher(foundation_base); + char foundation_str[9]; + snprintf(foundation_str, sizeof(foundation_str), "%08x", (unsigned int)(hash_value & 0xFFFFFFFF)); + return foundation_str; +} + +static SdpAttrCandidate::Ptr makeIceCandidate(std::string ip, uint16_t port, uint32_t priority = 100, + const std::string &proto = "udp", const std::string &type = "host", + const std::string &base_host = "", uint16_t base_port = 0, const std::string &stun_server = "") { + auto candidate = std::make_shared(); + candidate->foundation = calculateFoundation(ip, proto, type, stun_server); + candidate->component = 1; + candidate->transport = proto; + candidate->priority = priority; + candidate->address = std::move(ip); + candidate->port = port; + candidate->type = type; + if (strcasecmp(proto.c_str(), "tcp") == 0) { + candidate->type += " tcptype passive"; + } + + if (type != "host" && !base_host.empty() && base_port > 0) { + candidate->arr.emplace_back("raddr", base_host); + candidate->arr.emplace_back("rport", std::to_string(base_port)); + } + + return candidate; +} + +static CandidateInfo::Ptr makeCandidateInfoBySdpAttr(const SdpAttrCandidate& candidate_attr, const std::string& ufrag, const std::string& pwd) { + auto candidate = std::make_shared(); + candidate->_type = mappingCandidateTypeStr2Enum(candidate_attr.type); + candidate->_priority = candidate_attr.priority; + + candidate->_addr._host = candidate_attr.address; + candidate->_addr._port = candidate_attr.port; + candidate->_base_addr._host = candidate->_addr._host; + candidate->_base_addr._port = candidate->_addr._port; + candidate->_priority = candidate_attr.priority; + candidate->_ufrag = ufrag; + candidate->_pwd = pwd; + + if (CandidateInfo::AddressType::HOST == candidate->_type) { + candidate->_base_addr = candidate->_addr; + } else { + for (auto &pr : candidate_attr.arr) { + if (pr.first == "raddr") { + candidate->_base_addr._host = pr.second; + } + if (pr.first == "rport") { + candidate->_base_addr._port = atoi(pr.second.data()); + } } } - return "unknown"; + + if (strcasecmp(candidate_attr.transport.c_str(), "udp") == 0) { + candidate->_transport = CandidateTuple::TransportType::UDP; + candidate->_secure = CandidateTuple::SecureType::NOT_SECURE; + } else if (strcasecmp(candidate_attr.transport.c_str(), "tcp") == 0) { + candidate->_transport = CandidateTuple::TransportType::TCP; + candidate->_secure = CandidateTuple::SecureType::NOT_SECURE; + } + + return candidate; } +const char* WebRtcTransport::SignalingProtocolsStr(SignalingProtocols protocol) { + switch (protocol) { + case SignalingProtocols::WHEP_WHIP: return "whep_whip"; + case SignalingProtocols::WEBSOCKET: return "websocket"; + default: return "invalid"; + } +} + +const char* WebRtcTransport::RoleStr(Role role) { + switch (role) { + case Role::CLIENT: return "client"; + case Role::PEER: return "peer"; + default: return "none"; + } +} + WebRtcTransport::WebRtcTransport(const EventPoller::Ptr &poller) { _poller = poller; static auto prefix = getServerPrefix(); @@ -151,7 +284,18 @@ WebRtcTransport::WebRtcTransport(const EventPoller::Ptr &poller) { void WebRtcTransport::onCreate() { _dtls_transport = std::make_shared(_poller, this); - _ice_server = std::make_shared(this, _identifier, makeRandStr(24)); + IceAgent::Role role = IceAgent::Role::Controlling; + IceAgent::Implementation implementation = IceAgent::Implementation::Full; + + if (_role == Role::PEER) { + role = IceAgent::Role::Controlled; + if (_signaling_protocols == SignalingProtocols::WHEP_WHIP) { + implementation = IceAgent::Implementation::Lite; + } + } + + _ice_agent = std::make_shared(this, implementation, role, _identifier, makeRandStr(24), getPoller()); + _ice_agent->initialize(); } void WebRtcTransport::onDestory() { @@ -159,56 +303,144 @@ void WebRtcTransport::onDestory() { _sctp = nullptr; #endif _dtls_transport = nullptr; - _ice_server = nullptr; -} - -const EventPoller::Ptr &WebRtcTransport::getPoller() const { - return _poller; + _ice_agent = nullptr; } const string &WebRtcTransport::getIdentifier() const { return _identifier; } -const std::string& WebRtcTransport::deleteRandStr() const { +const std::string &WebRtcTransport::deleteRandStr() const { if (_delete_rand_str.empty()) { _delete_rand_str = makeRandStr(32); } return _delete_rand_str; } -////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void WebRtcTransport::getTransportInfo(const std::function& callback) const { + if (!callback) { + return; + } -void WebRtcTransport::OnIceServerSendStunPacket( - const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) { - sendSockData((char *)packet->GetData(), packet->GetSize(), tuple); + std::weak_ptr weak_self = shared_from_this(); + _poller->async([weak_self, callback]() { + Json::Value result; + auto strong_self = weak_self.lock(); + if (!strong_self) { + result["error"] = "Transport object destroyed"; + callback(std::move(result)); + return; + } + + try { + result["transport_id"] = strong_self->_identifier; + result["role"] = RoleStr(strong_self->_role); + result["signaling_protocol"] = SignalingProtocolsStr(strong_self->_signaling_protocols); + + result["has_offer_sdp"] = (strong_self->_offer_sdp != nullptr); + result["has_answer_sdp"] = (strong_self->_answer_sdp != nullptr); + result["dtls_state"] = strong_self->_dtls_transport? "connected" : "disconnected"; + result["srtp_send_ready"] = (strong_self->_srtp_session_send != nullptr); + result["srtp_recv_ready"] = (strong_self->_srtp_session_recv != nullptr); + + // ICE 连接检查列表信息 + if (strong_self->_ice_agent) { + Json::Value ice_info = strong_self->_ice_agent->getChecklistInfo(); + result["ice_checklists"] = ice_info; + } else { + result["ice_checklists"] = Json::nullValue; + } + + + } catch (const std::exception& ex) { + result["error"] = std::string("Exception occurred: ") + ex.what(); + } + + callback(std::move(result)); + }); } -void WebRtcTransportImp::OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) { - InfoL << getIdentifier() << " select tuple " << sockTypeStr(tuple) << " " << tuple->get_peer_ip() << ":" << tuple->get_peer_port(); - tuple->setSendFlushFlag(false); - unrefSelf(); +void WebRtcTransport::gatheringCandidate(IceServerInfo::Ptr ice_server, onGatheringCandidateCB cb) { + _on_gathering_candidate = std::move(cb); + _ice_agent->setIceServer(ice_server); + return _ice_agent->gatheringCandidate(ice_server, true, ice_server->_schema == IceServerInfo::SchemaType::TURN); } -void WebRtcTransport::OnIceServerConnected(const RTC::IceServer *iceServer) { +void WebRtcTransport::connectivityCheck(SdpAttrCandidate candidate_attr, const std::string& ufrag, const std::string& pwd) { + DebugL; + auto candidate = makeCandidateInfoBySdpAttr(candidate_attr, ufrag, pwd); + return _ice_agent->connectivityCheck(*candidate); +} + +void WebRtcTransport::connectivityCheckForSFU() { + DebugL; + // Connectivity Checks 连通性测试 + + auto answer_sdp = answerSdp(); + // TODO: 暂不支持每个媒体源,RTP,RTCP独立的candidates + for (auto &media : answer_sdp->media) { + for (auto &item : media.candidate) { + auto candidate = makeCandidateInfoBySdpAttr(item, media.ice_ufrag, media.ice_pwd); + _ice_agent->gatheringCandidate(candidate, false, false); + _ice_agent->connectivityCheck(*candidate); + } + } +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +void WebRtcTransport::onIceTransportCompleted() { InfoL << getIdentifier(); -} - -void WebRtcTransport::OnIceServerCompleted(const RTC::IceServer *iceServer) { - InfoL << getIdentifier(); - if (_answer_sdp->media[0].role == DtlsRole::passive) { + + if (!_answer_sdp) { + onShutdown(SockException(Err_other, "answer sdp not ready")); + return; + } + + _recv_ticker.resetTime(); + auto timeout = getTimeOutSec(); + weak_ptr weakSelf = static_pointer_cast(shared_from_this()); + _check_timer = std::make_shared(timeout / 2, [weakSelf, timeout]() { + auto strongSelf = weakSelf.lock(); + if (!strongSelf) { + return false; + } + if (strongSelf->_recv_ticker.elapsedTime() > timeout * 1000) { + // 接收媒体数据包超时 + strongSelf->onShutdown(SockException(Err_timeout, "webrtc data receive timeout")); + return false; + } + + return true; + }, getPoller()); + + if ((getRole() == Role::PEER && _answer_sdp->media[0].role == DtlsRole::passive) + || (getRole() == Role::CLIENT && _answer_sdp->media[0].role == DtlsRole::active)) { _dtls_transport->Run(RTC::DtlsTransport::Role::SERVER); } else { _dtls_transport->Run(RTC::DtlsTransport::Role::CLIENT); } } -void WebRtcTransport::OnIceServerDisconnected(const RTC::IceServer *iceServer) { +void WebRtcTransport::onIceTransportDisconnected() { InfoL << getIdentifier(); } +void WebRtcTransport::onIceTransportGatheringCandidate(const IceTransport::Pair::Ptr &pair, const CandidateInfo &candidate) { + InfoL << getIdentifier() << " get local candidate type " << candidate.dumpString(); + if (_on_gathering_candidate) { + auto type = mappingCandidateTypeEnum2Str(candidate._type); + auto sdpAttrCandidate = makeIceCandidate(candidate._addr._host, candidate._addr._port, candidate._priority, "udp", type, candidate._base_addr._host, candidate._base_addr._port); + _on_gathering_candidate(getIdentifier(), sdpAttrCandidate->toString(), candidate._ufrag, candidate._pwd); + } +} + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +void WebRtcTransport::setOnStartWebRTC(std::function on_start) { + _on_start = std::move(on_start); +} + void WebRtcTransport::OnDtlsTransportConnected( const RTC::DtlsTransport *dtlsTransport, RTC::SrtpSession::CryptoSuite srtpCryptoSuite, uint8_t *srtpLocalKey, size_t srtpLocalKeyLen, uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) { @@ -222,6 +454,9 @@ void WebRtcTransport::OnDtlsTransportConnected( _sctp->TransportConnected(); #endif onStartWebRTC(); + if (_on_start) { + _on_start(); + } } #pragma pack(push, 1) @@ -235,13 +470,12 @@ struct DtlsHeader { }; #pragma pack(pop) -void WebRtcTransport::OnDtlsTransportSendData( - const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) { +void WebRtcTransport::OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) { size_t offset = 0; - while(offset < len) { + while (offset < len) { auto *header = reinterpret_cast(data + offset); auto length = ntohs(header->length) + offsetof(DtlsHeader, payload); - sendSockData((char *)data + offset, length, nullptr); + sendSockData((char *)data + offset, length); offset += length; } } @@ -352,17 +586,39 @@ void WebRtcTransport::sendDatachannel(uint16_t streamId, uint32_t ppid, const ch ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -void WebRtcTransport::sendSockData(const char *buf, size_t len, RTC::TransportTuple *tuple) { +void WebRtcTransport::sendSockData(const char *buf, size_t len, const IceTransport::Pair::Ptr &pair) { auto pkt = _packet_pool.obtain2(); pkt->assign(buf, len); - onSendSockData(std::move(pkt), true, tuple ? tuple : _ice_server->GetSelectedTuple()); + onSendSockData(std::move(pkt), true, pair); } Session::Ptr WebRtcTransport::getSession() const { - auto tuple = _ice_server ? _ice_server->GetSelectedTuple(true) : nullptr; - return tuple ? static_pointer_cast(tuple->shared_from_this()) : nullptr; + auto pair = _ice_agent->getSelectedPair(); + return pair ? static_pointer_cast(pair->_socket->shared_from_this()) : nullptr; } +void WebRtcTransport::removePair(const SocketHelper *socket) { + _ice_agent->removePair(socket); +} + +void WebRtcTransport::setOnShutdown(function cb) { + _on_shutdown = cb ? std::move(cb) : [](const SockException &) {}; +} + +void WebRtcTransport::onShutdown(const SockException &ex) { + TraceL << ex; + if (_on_shutdown) { + _on_shutdown(ex); + } + if (_ice_agent) { + for (auto &pair : _ice_agent->getPairs()) { + if (pair->_socket) { + pair->_socket->shutdown(ex); + } + } + } +} + void WebRtcTransport::sendRtcpRemb(uint32_t ssrc, size_t bit_rate) { auto remb = FCI_REMB::create({ ssrc }, (uint32_t)bit_rate); auto fb = RtcpFB::create(PSFBType::RTCP_PSFB_REMB, remb.data(), remb.size()); @@ -388,22 +644,21 @@ string getFingerprint(const string &algorithm_str, const std::shared_ptrmedia[0] : _offer_sdp->media[0]; RTC::DtlsTransport::Fingerprint remote_fingerprint; - remote_fingerprint.algorithm - = RTC::DtlsTransport::GetFingerprintAlgorithm(_offer_sdp->media[0].fingerprint.algorithm); - remote_fingerprint.value = _offer_sdp->media[0].fingerprint.hash; + remote_fingerprint.algorithm = RTC::DtlsTransport::GetFingerprintAlgorithm(media.fingerprint.algorithm); + remote_fingerprint.value = media.fingerprint.hash; _dtls_transport->SetRemoteFingerprint(remote_fingerprint); } void WebRtcTransport::onRtcConfigure(RtcConfigure &configure) const { SdpAttrFingerprint fingerprint; - fingerprint.algorithm = _offer_sdp->media[0].fingerprint.algorithm; + fingerprint.algorithm = _offer_sdp ? _offer_sdp->media[0].fingerprint.algorithm : "sha-256"; fingerprint.hash = getFingerprint(fingerprint.algorithm, _dtls_transport); - configure.setDefaultSetting( - _ice_server->GetUsernameFragment(), _ice_server->GetPassword(), RtpDirection::sendrecv, fingerprint); + configure.setDefaultSetting(_ice_agent->getUfrag(), _ice_agent->getPassword(), RtpDirection::sendrecv, fingerprint); // 开启remb后关闭twcc,因为开启twcc后remb无效 [AUTO-TRANSLATED:8a8feca2] // Turn off twcc after turning on remb, because remb is invalid after turning on twcc @@ -425,6 +680,18 @@ static void setSdpBitrate(RtcSession &sdp) { } } +std::string WebRtcTransport::createOfferSdp() { + try { + RtcConfigure configure; + onRtcConfigure(configure); + _offer_sdp = configure.createOffer(); + return _offer_sdp->toString(); + } catch (exception &ex) { + onShutdown(SockException(Err_shutdown, ex.what())); + throw; + } +} + std::string WebRtcTransport::getAnswerSdp(const string &offer) { try { // // 解析offer sdp //// [AUTO-TRANSLATED:87c1f337] @@ -433,7 +700,7 @@ std::string WebRtcTransport::getAnswerSdp(const string &offer) { _offer_sdp->loadFrom(offer); onCheckSdp(SdpType::offer, *_offer_sdp); _offer_sdp->checkValid(); - setRemoteDtlsFingerprint(*_offer_sdp); + setRemoteDtlsFingerprint(SdpType::offer, *_offer_sdp); // // sdp 配置 //// [AUTO-TRANSLATED:718a72e2] // // sdp configuration //// @@ -453,18 +720,39 @@ std::string WebRtcTransport::getAnswerSdp(const string &offer) { } } -static bool isDtls(char *buf) { +void WebRtcTransport::setAnswerSdp(const std::string &answer) { + try { + _answer_sdp = std::make_shared(); + _answer_sdp->loadFrom(answer); + onCheckSdp(SdpType::answer, *_answer_sdp); + _answer_sdp->checkValid(); + setRemoteDtlsFingerprint(SdpType::answer, *_answer_sdp); + } catch (exception &ex) { + onShutdown(SockException(Err_shutdown, ex.what())); + throw; + } +} + +static bool isDtls(const char *buf) { return ((*buf > 19) && (*buf < 64)); } -void WebRtcTransport::inputSockData(char *buf, int len, RTC::TransportTuple *tuple) { - if (RTC::StunPacket::IsStun((const uint8_t *)buf, len)) { - std::unique_ptr packet(RTC::StunPacket::Parse((const uint8_t *)buf, len)); - if (!packet) { - WarnL << "parse stun error"; - return; - } - _ice_server->ProcessStunPacket(packet.get(), tuple); +void WebRtcTransport::inputSockData(const char *buf, int len, const SocketHelper::Ptr& socket, struct sockaddr *addr, int addr_len) { + IceTransport::Pair::Ptr pair; + if (addr != nullptr) { + auto peer_host = SockUtil::inet_ntoa(addr); + auto peer_port = SockUtil::inet_port(addr); + pair = std::make_shared(socket, std::move(peer_host), peer_port); + } else { + pair = std::make_shared(socket); + } + return inputSockData(buf, len, pair); +} + +void WebRtcTransport::inputSockData(const char *buf, int len, const IceTransport::Pair::Ptr& pair) { + // DebugL; + _recv_ticker.resetTime(); + if (_ice_agent->processSocketData((const uint8_t *)buf, len, pair)) { return; } if (isDtls(buf)) { @@ -473,7 +761,7 @@ void WebRtcTransport::inputSockData(char *buf, int len, RTC::TransportTuple *tup } if (isRtp(buf, len)) { if (!_srtp_session_recv) { - WarnL << "received rtp packet when dtls not completed from:" << tuple->get_peer_ip(); + WarnL << "received rtp packet when dtls not completed from:" << pair->get_peer_ip(); return; } if (_srtp_session_recv->DecryptSrtp((uint8_t *)buf, &len)) { @@ -483,7 +771,7 @@ void WebRtcTransport::inputSockData(char *buf, int len, RTC::TransportTuple *tup } if (isRtcp(buf, len)) { if (!_srtp_session_recv) { - WarnL << "received rtcp packet when dtls not completed from:" << tuple->get_peer_ip(); + WarnL << "received rtcp packet when dtls not completed from:" << pair->get_peer_ip(); return; } if (_srtp_session_recv->DecryptSrtcp((uint8_t *)buf, &len)) { @@ -570,38 +858,26 @@ void WebRtcTransportImp::onDestory() { unregisterSelf(); } -void WebRtcTransportImp::onSendSockData(Buffer::Ptr buf, bool flush, RTC::TransportTuple *tuple) { - if (tuple == nullptr) { - tuple = _ice_server->GetSelectedTuple(); - if (!tuple) { - WarnL << "send data failed:" << buf->size(); - return; - } - } +void WebRtcTransportImp::onSendSockData(Buffer::Ptr buf, bool flush, const IceTransport::Pair::Ptr& pair) { + return _ice_agent->sendSocketData(buf, pair, flush); +} - // 一次性发送一帧的rtp数据,提高网络io性能 [AUTO-TRANSLATED:fbab421e] - // Send one frame of rtp data at a time to improve network io performance - if (tuple->getSock()->sockType() == SockNum::Sock_TCP) { - // 增加tcp两字节头 [AUTO-TRANSLATED:62159f79] - // Add two-byte header to tcp - auto len = buf->size(); - char tcp_len[2] = { 0 }; - tcp_len[0] = (len >> 8) & 0xff; - tcp_len[1] = len & 0xff; - tuple->SockSender::send(tcp_len, 2); - } - tuple->send(std::move(buf)); +/////////////////////////////////////////////////////////////////// +bool WebRtcTransportImp::canSendRtp(const RtcMedia& m) const { + return (getRole() == WebRtcTransport::Role::PEER && m.direction == RtpDirection::sendonly) + || (getRole() == WebRtcTransport::Role::CLIENT && m.direction == RtpDirection::recvonly) + || (m.direction == RtpDirection::sendrecv); +} - if (flush) { - tuple->flushAll(); - } +bool WebRtcTransportImp::canRecvRtp(const RtcMedia& m) const { + return (getRole() == WebRtcTransport::Role::PEER && m.direction == RtpDirection::recvonly) + || (getRole() == WebRtcTransport::Role::CLIENT && m.direction == RtpDirection::sendonly) + || (m.direction == RtpDirection::sendrecv); } -/////////////////////////////////////////////////////////////////// - bool WebRtcTransportImp::canSendRtp() const { for (auto &m : _answer_sdp->media) { - if (m.direction == RtpDirection::sendrecv || m.direction == RtpDirection::sendonly) { + if (canSendRtp(m)) { return true; } } @@ -610,7 +886,7 @@ bool WebRtcTransportImp::canSendRtp() const { bool WebRtcTransportImp::canRecvRtp() const { for (auto &m : _answer_sdp->media) { - if (m.direction == RtpDirection::sendrecv || m.direction == RtpDirection::recvonly) { + if (canRecvRtp(m)) { return true; } } @@ -637,7 +913,7 @@ void WebRtcTransportImp::onStartWebRTC() { track->rtcp_context_send = std::make_shared(); // rtp track type --> MediaTrack - if (m_answer.direction == RtpDirection::sendonly || m_answer.direction == RtpDirection::sendrecv) { + if (canSendRtp(m_answer)) { // 该类型的track 才支持发送 [AUTO-TRANSLATED:b7c1e631] // This type of track supports sending _type_to_track[m_answer.type] = track; @@ -767,25 +1043,6 @@ void WebRtcTransportImp::onCheckSdp(SdpType type, RtcSession &sdp) { } } -SdpAttrCandidate::Ptr -makeIceCandidate(std::string ip, uint16_t port, uint32_t priority = 100, std::string proto = "udp") { - auto candidate = std::make_shared(); - // rtp端口 [AUTO-TRANSLATED:b0addb27] - // rtp port - candidate->component = 1; - candidate->transport = proto; - candidate->foundation = proto + "candidate"; - // 优先级,单candidate时随便 [AUTO-TRANSLATED:7c85d820] - // Priority, random when there is only one candidate - candidate->priority = priority; - candidate->address = std::move(ip); - candidate->port = port; - candidate->type = "host"; - if (proto == "tcp") { - candidate->type += " tcptype passive"; - } - return candidate; -} void WebRtcTransportImp::onRtcConfigure(RtcConfigure &configure) const { WebRtcTransport::onRtcConfigure(configure); @@ -796,29 +1053,41 @@ void WebRtcTransportImp::onRtcConfigure(RtcConfigure &configure) const { return; } - GET_CONFIG(uint16_t, local_udp_port, Rtc::kPort); - GET_CONFIG(uint16_t, local_tcp_port, Rtc::kTcpPort); - // 添加接收端口candidate信息 [AUTO-TRANSLATED:cc9a6a90] - // Add the receiving port candidate information - GET_CONFIG_FUNC(std::vector, extern_ips, Rtc::kExternIP, [](string str) { - std::vector ret; - if (str.length()) { - ret = split(str, ","); - } - translateIPFromEnv(ret); - return ret; - }); - if (extern_ips.empty()) { - std::string local_ip = _local_ip.empty() ? SockUtil::get_local_ip() : _local_ip; - if (local_udp_port) { configure.addCandidate(*makeIceCandidate(local_ip, local_udp_port, 120, "udp")); } - if (local_tcp_port) { configure.addCandidate(*makeIceCandidate(local_ip, local_tcp_port, _preferred_tcp ? 125 : 115, "tcp")); } - } else { - const uint32_t delta = 10; - uint32_t priority = 100 + delta * extern_ips.size(); - for (auto ip : extern_ips) { - if (local_udp_port) { configure.addCandidate(*makeIceCandidate(ip, local_udp_port, priority, "udp")); } - if (local_tcp_port) { configure.addCandidate(*makeIceCandidate(ip, local_tcp_port, priority - (_preferred_tcp ? -5 : 5), "tcp")); } - priority -= delta; + //P2P的不直接在answer中返回candication + if (getSignalingProtocols() == SignalingProtocols::WHEP_WHIP) { + + GET_CONFIG(uint16_t, local_udp_port, Rtc::kPort); + GET_CONFIG(uint16_t, local_tcp_port, Rtc::kTcpPort); + // 添加接收端口candidate信息 [AUTO-TRANSLATED:cc9a6a90] + // Add the receiving port candidate information + GET_CONFIG_FUNC(std::vector, extern_ips, Rtc::kExternIP, [](string str) { + std::vector ret; + if (str.length()) { + ret = split(str, ","); + } + translateIPFromEnv(ret); + return ret; + }); + if (extern_ips.empty()) { + std::string local_ip = _local_ip.empty() ? SockUtil::get_local_ip() : _local_ip; + if (local_udp_port) { + configure.addCandidate(*makeIceCandidate(local_ip, local_udp_port, 120, "udp")); + } + if (local_tcp_port) { + configure.addCandidate(*makeIceCandidate(local_ip, local_tcp_port, _preferred_tcp ? 125 : 115, "tcp")); + } + } else { + const uint32_t delta = 10; + uint32_t priority = 100 + delta * extern_ips.size(); + for (auto ip : extern_ips) { + if (local_udp_port) { + configure.addCandidate(*makeIceCandidate(ip, local_udp_port, priority, "udp")); + } + if (local_tcp_port) { + configure.addCandidate(*makeIceCandidate(ip, local_tcp_port, priority - (_preferred_tcp ? -5 : 5), "tcp")); + } + priority -= delta; + } } } } @@ -839,7 +1108,7 @@ void WebRtcTransportImp::setIceCandidate(vector cands) { class RtpChannel : public RtpTrackImp, public std::enable_shared_from_this { public: - RtpChannel(EventPoller::Ptr poller, RtpTrackImp::OnSorted cb, function on_nack) { + RtpChannel(TrackType type, EventPoller::Ptr poller, RtpTrackImp::OnSorted cb, function on_nack) : _nack_ctx(type){ _poller = std::move(poller); _on_nack = std::move(on_nack); setOnSorted(std::move(cb)); @@ -865,9 +1134,10 @@ public: } return rtp; } - - Buffer::Ptr createRtcpRR(RtcpHeader *sr, uint32_t ssrc) { + void onRtcp(RtcpHeader *sr) { _rtcp_context.onRtcp(sr); + } + Buffer::Ptr createRtcpRR(uint32_t ssrc) { return _rtcp_context.createRtcpRR(ssrc, getSSRC()); } @@ -953,8 +1223,7 @@ void WebRtcTransportImp::onRtcp(const char *buf, size_t len) { // 设置rtp时间戳与ntp时间戳的对应关系 [AUTO-TRANSLATED:e92f4749] // Set the correspondence between rtp timestamp and ntp timestamp rtp_chn->setNtpStamp(sr->rtpts, sr->getNtpUnixStampMS()); - auto rr = rtp_chn->createRtcpRR(sr, track->answer_ssrc_rtp); - sendRtcpPacket(rr->data(), rr->size(), true); + rtp_chn->onRtcp(sr); } } else { WarnL << "未识别的sr rtcp包:" << rtcp->dumpString(); @@ -971,8 +1240,6 @@ void WebRtcTransportImp::onRtcp(const char *buf, size_t len) { if (it != _ssrc_to_track.end()) { auto &track = it->second; track->rtcp_context_send->onRtcp(rtcp); - auto sr = track->rtcp_context_send->createRtcpSR(track->answer_ssrc_rtp); - sendRtcpPacket(sr->data(), sr->size(), true); } else { WarnL << "未识别的rr rtcp包:" << rtcp->dumpString(); } @@ -1053,7 +1320,7 @@ void WebRtcTransportImp::createRtpChannel(const string &rid, uint32_t ssrc, Medi // rid --> RtpReceiverImp auto &ref = track.rtp_channel[rid]; weak_ptr weak_self = static_pointer_cast(shared_from_this()); - ref = std::make_shared( + ref = std::make_shared(track.media->type, getPoller(), [&track, this, rid](RtpPacket::Ptr rtp) mutable { onSortedRtp(track, rid, std::move(rtp)); }, [&track, weak_self, ssrc](const FCI_NACK &nack) mutable { // nack发送可能由定时器异步触发 [AUTO-TRANSLATED:186d6723] @@ -1082,6 +1349,22 @@ void WebRtcTransportImp::onRtp(const char *buf, size_t len, uint64_t stamp_ms) { WarnL << "unknown rtp pt:" << (int)rtp->pt; return; } + + if (_rtcp_rr_send_ticker.elapsedTime() > 5000) { + _rtcp_rr_send_ticker.resetTime(); + for (auto& it : _ssrc_to_track) { + auto ssrc = it.first; + auto &track = it.second; + auto rtp_chn = track->getRtpChannel(ssrc); + if (rtp_chn) { + auto rr = rtp_chn->createRtcpRR(track->answer_ssrc_rtp); + if (rr && rr->size() > 0) { + sendRtcpPacket(rr->data(), rr->size(), true); + } + } + } + } + it->second->inputRtp(buf, len, stamp_ms, rtp); } @@ -1220,6 +1503,16 @@ void WebRtcTransportImp::onSendRtp(const RtpPacket::Ptr &rtp, bool flush, bool r pair ctx { rtx, track.get() }; sendRtpPacket(rtp->data() + RtpPacket::kRtpTcpHeaderSize, rtp->size() - RtpPacket::kRtpTcpHeaderSize, flush, &ctx); _bytes_usage += rtp->size() - RtpPacket::kRtpTcpHeaderSize; + + if (_rtcp_sr_send_ticker.elapsedTime() > 5000) { + _rtcp_sr_send_ticker.resetTime(); + if (track->rtcp_context_send) { + auto sr = track->rtcp_context_send->createRtcpSR(track->answer_ssrc_rtp); + if (sr && sr->size() > 0) { + sendRtcpPacket(sr->data(), sr->size(), true); + } + } + } } void WebRtcTransportImp::onBeforeEncryptRtp(const char *buf, int &len, void *ctx) { @@ -1278,15 +1571,8 @@ void WebRtcTransportImp::safeShutdown(const SockException &ex) { void WebRtcTransportImp::onShutdown(const SockException &ex) { WarnL << ex; + WebRtcTransport::onShutdown(ex); unrefSelf(); - for (auto &tuple : _ice_server->GetTuples()) { - tuple->shutdown(ex); - } -} - -void WebRtcTransportImp::removeTuple(RTC::TransportTuple *tuple) { - InfoL << getIdentifier() << " remove tuple " << tuple->get_peer_ip() << ":" << tuple->get_peer_port(); - this->_ice_server->RemoveTuple(tuple); } uint64_t WebRtcTransportImp::getBytesUsage() const { @@ -1311,6 +1597,7 @@ void WebRtcTransportImp::unrefSelf() { } void WebRtcTransportImp::unregisterSelf() { + DebugL; unrefSelf(); WebRtcTransportManager::Instance().removeItem(getIdentifier()); } @@ -1350,23 +1637,23 @@ WebRtcPluginManager &WebRtcPluginManager::Instance() { } void WebRtcPluginManager::registerPlugin(const string &type, Plugin cb) { + InfoL << "Load webrtc plugin:" << type; lock_guard lck(_mtx_creator); _map_creator[type] = std::move(cb); } - void WebRtcPluginManager::setListener(Listener cb) { lock_guard lck(_mtx_creator); _listener = std::move(cb); } -void WebRtcPluginManager::negotiateSdp(Session &sender, const string &type, const WebRtcArgs &args, const onCreateWebRtc &cb_in) { +void WebRtcPluginManager::negotiateSdp(SocketHelper& sender, const string &type, const WebRtcArgs &args, const onCreateWebRtc &cb_in) { onCreateWebRtc cb; lock_guard lck(_mtx_creator); if (_listener) { auto listener = _listener; auto args_ptr = args.shared_from_this(); - auto sender_ptr = static_pointer_cast(sender.shared_from_this()); + auto sender_ptr = static_pointer_cast(sender.shared_from_this()); cb = [listener, sender_ptr, type, args_ptr, cb_in](const WebRtcInterface &rtc) { listener(*sender_ptr, type, *args_ptr, rtc); cb_in(rtc); @@ -1383,11 +1670,12 @@ void WebRtcPluginManager::negotiateSdp(Session &sender, const string &type, cons it->second(sender, args, cb); } -void echo_plugin(Session &sender, const WebRtcArgs &args, const onCreateWebRtc &cb) { +void echo_plugin(SocketHelper& sender, const WebRtcArgs &args, const onCreateWebRtc &cb) { cb(*WebRtcEchoTest::create(EventPollerPool::Instance().getPoller())); } -void push_plugin(Session &sender, const WebRtcArgs &args, const onCreateWebRtc &cb) { +template +void push_plugin(SocketHelper& sender, const WebRtcArgs &args, const onCreateWebRtc &cb) { MediaInfo info(args["url"]); Broadcast::PublishAuthInvoker invoker = [cb, info](const string &err, const ProtocolOption &option) mutable { if (!err.empty()) { @@ -1431,7 +1719,8 @@ void push_plugin(Session &sender, const WebRtcArgs &args, const onCreateWebRtc & push_src_ownership = push_src->getOwnership(); push_src->setProtocolOption(option); } - auto rtc = WebRtcPusher::create(EventPollerPool::Instance().getPoller(), push_src, push_src_ownership, info, option); + auto rtc = Type::create(EventPollerPool::Instance().getPoller(), push_src, push_src_ownership, info, option, + WebRtcTransport::Role::PEER, WebRtcTransport::SignalingProtocols::WHEP_WHIP); push_src->setListener(rtc); cb(*rtc); }; @@ -1446,7 +1735,9 @@ void push_plugin(Session &sender, const WebRtcArgs &args, const onCreateWebRtc & } } -void play_plugin(Session &sender, const WebRtcArgs &args, const onCreateWebRtc &cb) { +template +void play_plugin(SocketHelper &sender, const WebRtcArgs &args, const onCreateWebRtc &cb) { + MediaInfo info(args["url"]); auto session_ptr = static_pointer_cast(sender.shared_from_this()); Broadcast::AuthInvoker invoker = [cb, info, session_ptr](const string &err) mutable { @@ -1466,7 +1757,8 @@ void play_plugin(Session &sender, const WebRtcArgs &args, const onCreateWebRtc & // 还原成rtc,目的是为了hook时识别哪种播放协议 [AUTO-TRANSLATED:fe8dd2dc] // Restore to RTC, the purpose is to identify which playback protocol during hooking info.schema = "rtc"; - auto rtc = WebRtcPlayer::create(EventPollerPool::Instance().getPoller(), src, info); + auto rtc = Type::create(EventPollerPool::Instance().getPoller(), src, info, + WebRtcTransport::Role::PEER, WebRtcTransport::SignalingProtocols::WHEP_WHIP); cb(*rtc); }); }; @@ -1497,9 +1789,12 @@ static void setWebRtcArgs(const WebRtcArgs &args, WebRtcInterface &rtc) { } } - bool preferred_tcp = args["preferred_tcp"]; - { + auto preferred_tcp = args["preferred_tcp"]; + if (!preferred_tcp.empty()) { rtc.setPreferredTcp(preferred_tcp); + } else { + GET_CONFIG(bool, s_preferred_tcp, Rtc::kPreferredTcp); + rtc.setPreferredTcp(s_preferred_tcp); } { @@ -1532,17 +1827,47 @@ static void setWebRtcArgs(const WebRtcArgs &args, WebRtcInterface &rtc) { } } +float WebRtcTransport::getTimeOutSec() { + GET_CONFIG(uint32_t, timeout, Rtc::kTimeOutSec); + if (timeout <= 0) { + WarnL << "config rtc. " << Rtc::kTimeOutSec << ": " << timeout << " not vaild"; + return 5; + } + return (float)timeout; +} + static onceToken s_rtc_auto_register([]() { #if !defined (NDEBUG) // debug模式才开启echo插件 [AUTO-TRANSLATED:48fcb116] // Enable echo plugin only in debug mode WebRtcPluginManager::Instance().registerPlugin("echo", echo_plugin); #endif - WebRtcPluginManager::Instance().registerPlugin("push", push_plugin); - WebRtcPluginManager::Instance().registerPlugin("play", play_plugin); - WebRtcPluginManager::Instance().setListener([](Session &sender, const std::string &type, const WebRtcArgs &args, const WebRtcInterface &rtc) { + WebRtcPluginManager::Instance().registerPlugin("push", push_plugin); + WebRtcPluginManager::Instance().registerPlugin("play", play_plugin); + WebRtcPluginManager::Instance().registerPlugin("talk", play_plugin); + + WebRtcPluginManager::Instance().setListener([](SocketHelper& sender, const std::string &type, const WebRtcArgs &args, const WebRtcInterface &rtc) { setWebRtcArgs(args, const_cast(rtc)); }); }); +void WebRtcTransport::onIceTransportRecvData(const toolkit::Buffer::Ptr& buffer, const IceTransport::Pair::Ptr& pair) { + return inputSockData(buffer->data(), buffer->size(), pair); +} + +void translateIPFromEnv(std::vector &v) { + for (auto iter = v.begin(); iter != v.end();) { + if (start_with(*iter, "$")) { + auto ip = toolkit::getEnv(*iter); + if (ip.empty()) { + iter = v.erase(iter); + } else { + *iter++ = ip; + } + } else { + ++iter; + } + } +} + }// namespace mediakit diff --git a/webrtc/WebRtcTransport.h b/webrtc/WebRtcTransport.h index a3553c7c..f660f981 100644 --- a/webrtc/WebRtcTransport.h +++ b/webrtc/WebRtcTransport.h @@ -8,12 +8,14 @@ * may be found in the AUTHORS file in the root of the source tree. */ -#pragma once +#ifndef ZLMEDIAKIT_WEBRTC_TRANSPORT_H +#define ZLMEDIAKIT_WEBRTC_TRANSPORT_H #include #include +#include #include "DtlsTransport.hpp" -#include "IceServer.hpp" +#include "IceTransport.hpp" #include "SrtpSession.hpp" #include "StunPacket.hpp" #include "Sdp.h" @@ -25,128 +27,158 @@ #include "TwccContext.h" #include "SctpAssociation.hpp" #include "Rtcp/RtcpContext.h" +#include "Rtsp/RtspMediaSource.h" +using namespace RTC; namespace mediakit { +// ICE transport policy enum +enum class IceTransportPolicy { + kAll = 0, // 不限制,支持所有连接类型(默认) + kRelayOnly = 1, // 仅支持Relay转发 + kP2POnly = 2 // 仅支持P2P直连 +}; + // RTC配置项目 [AUTO-TRANSLATED:65784416] // RTC configuration project namespace Rtc { extern const std::string kPort; extern const std::string kTcpPort; extern const std::string kTimeOutSec; -extern const std::string kTranscodeG711; +extern const std::string kTranscodeG711; +extern const std::string kSignalingPort; +extern const std::string kSignalingSslPort; +extern const std::string kIcePort; +extern const std::string kIceTcpPort; +extern const std::string kEnableTurn; +extern const std::string kIceTransportPolicy; +extern const std::string kIceUfrag; +extern const std::string kIcePwd; +extern const std::string kExternIP; +extern const std::string kInterfaces; }//namespace RTC class WebRtcInterface { public: virtual ~WebRtcInterface() = default; virtual std::string getAnswerSdp(const std::string &offer) = 0; + virtual std::string createOfferSdp() = 0; + virtual void setAnswerSdp(const std::string &answer) = 0; virtual const std::string& getIdentifier() const = 0; virtual const std::string& deleteRandStr() const { static std::string s_null; return s_null; } virtual void setIceCandidate(std::vector cands) {} virtual void setLocalIp(std::string localIp) {} virtual void setPreferredTcp(bool flag) {} + + using onGatheringCandidateCB = std::function; + virtual void gatheringCandidate(IceServerInfo::Ptr ice_server, onGatheringCandidateCB cb = nullptr) = 0; }; class WebRtcException : public WebRtcInterface { public: - WebRtcException(const SockException &ex) : _ex(ex) {}; + WebRtcException(const toolkit::SockException &ex) : _ex(ex) {}; + + std::string createOfferSdp() override { + throw _ex; + } + std::string getAnswerSdp(const std::string &offer) override { throw _ex; } + + void setAnswerSdp(const std::string &answer) override { + throw _ex; + } + + void gatheringCandidate(IceServerInfo::Ptr ice_server, onGatheringCandidateCB cb = nullptr) override { + throw _ex; + } + const std::string &getIdentifier() const override { static std::string s_null; return s_null; } private: - SockException _ex; + toolkit::SockException _ex; }; -class WebRtcTransport : public WebRtcInterface, public RTC::DtlsTransport::Listener, public RTC::IceServer::Listener, public std::enable_shared_from_this +class WebRtcTransport : public WebRtcInterface, public RTC::DtlsTransport::Listener, public IceTransport::Listener, public std::enable_shared_from_this #ifdef ENABLE_SCTP , public RTC::SctpAssociation::Listener #endif { public: + enum class Role { + NONE = 0, + CLIENT, + PEER, + }; + static const char* RoleStr(Role role); + + enum class SignalingProtocols { + Invalid = -1, + WHEP_WHIP = 0, + WEBSOCKET = 1, //FOR P2P + }; + static const char* SignalingProtocolsStr(SignalingProtocols protocol); + + using WeakPtr = std::weak_ptr; using Ptr = std::shared_ptr; - WebRtcTransport(const EventPoller::Ptr &poller); + WebRtcTransport(const toolkit::EventPoller::Ptr &poller); - /** - * 创建对象 - * Create object - - * [AUTO-TRANSLATED:830344e4] - */ virtual void onCreate(); - /** - * 销毁对象 - * Destroy object - - * [AUTO-TRANSLATED:1016b97b] - */ virtual void onDestory(); - /** - * 创建webrtc answer sdp - * @param offer offer sdp - * @return answer sdp - * Create webrtc answer sdp - * @param offer offer sdp - * @return answer sdp - - * [AUTO-TRANSLATED:d9b027d7] - */ - std::string getAnswerSdp(const std::string &offer) override final; + std::string getAnswerSdp(const std::string &offer) override; + void setAnswerSdp(const std::string &answer) override; - /** - * 获取对象唯一id - * Get object unique id - - * [AUTO-TRANSLATED:9ad519c6] - */ + const RtcSession::Ptr& answerSdp() const { + return _answer_sdp; + } + + std::string createOfferSdp() override; + const std::string& getIdentifier() const override; const std::string& deleteRandStr() const override; - /** - * socket收到udp数据 - * @param buf 数据指针 - * @param len 数据长度 - * @param tuple 数据来源 - * Socket receives udp data - * @param buf data pointer - * @param len data length - * @param tuple data source - - * [AUTO-TRANSLATED:1ee86069] - */ - void inputSockData(char *buf, int len, RTC::TransportTuple *tuple); - - /** - * 发送rtp - * @param buf rtcp内容 - * @param len rtcp长度 - * @param flush 是否flush socket - * @param ctx 用户指针 - * Send rtp - * @param buf rtcp content - * @param len rtcp length - * @param flush whether to flush socket - * @param ctx user pointer - - * [AUTO-TRANSLATED:aa833695] - */ + void inputSockData(const char *buf, int len, const toolkit::SocketHelper::Ptr& socket, struct sockaddr *addr = nullptr, int addr_len = 0); + void inputSockData(const char *buf, int len, const IceTransport::Pair::Ptr& pair = nullptr); void sendRtpPacket(const char *buf, int len, bool flush, void *ctx = nullptr); void sendRtcpPacket(const char *buf, int len, bool flush, void *ctx = nullptr); void sendDatachannel(uint16_t streamId, uint32_t ppid, const char *msg, size_t len); - const EventPoller::Ptr& getPoller() const; - Session::Ptr getSession() const; + const toolkit::EventPoller::Ptr &getPoller() const { return _poller; } + void setPoller(toolkit::EventPoller::Ptr poller) { _poller = std::move(poller); } + toolkit::Session::Ptr getSession() const; + void removePair(const toolkit::SocketHelper *socket); + + Role getRole() const { return _role; } + void setRole(Role role) { _role = role; } + + SignalingProtocols getSignalingProtocols() const { return _signaling_protocols; } + void setSignalingProtocols(SignalingProtocols signaling_protocols) { _signaling_protocols = signaling_protocols; } + + float getTimeOutSec(); + + void getTransportInfo(const std::function &callback) const; + size_t getRecvSpeed() const { return _ice_agent ? _ice_agent->getRecvSpeed() : 0; } + size_t getRecvTotalBytes() const { return _ice_agent ? _ice_agent->getRecvTotalBytes() : 0; } + size_t getSendSpeed() const { return _ice_agent ? _ice_agent->getSendSpeed() : 0; } + size_t getSendTotalBytes() const { return _ice_agent ? _ice_agent->getSendTotalBytes() : 0; } + + void setOnShutdown(std::function cb); + + void gatheringCandidate(IceServerInfo::Ptr ice_server, onGatheringCandidateCB cb = nullptr) override; + void connectivityCheck(SdpAttrCandidate candidate_attr, const std::string &ufrag, const std::string &pwd); + void connectivityCheckForSFU(); + + void setOnStartWebRTC(std::function on_start); + protected: - // // dtls相关的回调 //// [AUTO-TRANSLATED:31a1f32c] - // // dtls related callbacks //// + // DtlsTransport::Listener; dtls相关的回调 void OnDtlsTransportConnecting(const RTC::DtlsTransport *dtlsTransport) override; void OnDtlsTransportConnected(const RTC::DtlsTransport *dtlsTransport, RTC::SrtpSession::CryptoSuite srtpCryptoSuite, @@ -155,20 +187,19 @@ protected: uint8_t *srtpRemoteKey, size_t srtpRemoteKeyLen, std::string &remoteCert) override; - void OnDtlsTransportFailed(const RTC::DtlsTransport *dtlsTransport) override; void OnDtlsTransportClosed(const RTC::DtlsTransport *dtlsTransport) override; void OnDtlsTransportSendData(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; protected: - // // ice相关的回调 /// [AUTO-TRANSLATED:30abf693] - // // ice related callbacks /// - void OnIceServerSendStunPacket(const RTC::IceServer *iceServer, const RTC::StunPacket *packet, RTC::TransportTuple *tuple) override; - void OnIceServerConnected(const RTC::IceServer *iceServer) override; - void OnIceServerCompleted(const RTC::IceServer *iceServer) override; - void OnIceServerDisconnected(const RTC::IceServer *iceServer) override; + // ice相关的回调; IceTransport::Listener. + void onIceTransportRecvData(const toolkit::Buffer::Ptr& buffer, const IceTransport::Pair::Ptr& pair) override; + void onIceTransportGatheringCandidate(const IceTransport::Pair::Ptr& pair, const CandidateInfo& candidate) override; + void onIceTransportCompleted() override; + void onIceTransportDisconnected() override; + // SctpAssociation::Listener #ifdef ENABLE_SCTP void OnSctpAssociationConnecting(RTC::SctpAssociation* sctpAssociation) override; void OnSctpAssociationConnected(RTC::SctpAssociation* sctpAssociation) override; @@ -183,11 +214,11 @@ protected: virtual void onStartWebRTC() = 0; virtual void onRtcConfigure(RtcConfigure &configure) const; virtual void onCheckSdp(SdpType type, RtcSession &sdp) = 0; - virtual void onSendSockData(Buffer::Ptr buf, bool flush = true, RTC::TransportTuple *tuple = nullptr) = 0; + virtual void onSendSockData(toolkit::Buffer::Ptr buf, bool flush = true, const IceTransport::Pair::Ptr& pair = nullptr) = 0; virtual void onRtp(const char *buf, size_t len, uint64_t stamp_ms) = 0; virtual void onRtcp(const char *buf, size_t len) = 0; - virtual void onShutdown(const SockException &ex) = 0; + virtual void onShutdown(const toolkit::SockException &ex); virtual void onBeforeEncryptRtp(const char *buf, int &len, void *ctx) = 0; virtual void onBeforeEncryptRtcp(const char *buf, int &len, void *ctx) = 0; virtual void onRtcpBye() = 0; @@ -197,26 +228,36 @@ protected: void sendRtcpPli(uint32_t ssrc); private: - void sendSockData(const char *buf, size_t len, RTC::TransportTuple *tuple); - void setRemoteDtlsFingerprint(const RtcSession &remote); + void sendSockData(const char *buf, size_t len, const IceTransport::Pair::Ptr& pair = nullptr); + void setRemoteDtlsFingerprint(SdpType type, const RtcSession &remote); protected: + SignalingProtocols _signaling_protocols = SignalingProtocols::WHEP_WHIP; + Role _role = Role::PEER; RtcSession::Ptr _offer_sdp; RtcSession::Ptr _answer_sdp; - std::shared_ptr _ice_server; + IceAgent::Ptr _ice_agent; + onGatheringCandidateCB _on_gathering_candidate = nullptr; + private: mutable std::string _delete_rand_str; std::string _identifier; - EventPoller::Ptr _poller; - std::shared_ptr _dtls_transport; - std::shared_ptr _srtp_session_send; - std::shared_ptr _srtp_session_recv; - Ticker _ticker; + toolkit::EventPoller::Ptr _poller; + DtlsTransport::Ptr _dtls_transport; + SrtpSession::Ptr _srtp_session_send; + SrtpSession::Ptr _srtp_session_recv; + toolkit::Ticker _ticker; // 循环池 [AUTO-TRANSLATED:b7059f37] // Cycle pool - ResourcePool _packet_pool; + toolkit::ResourcePool _packet_pool; + //超时功能实现 + toolkit::Ticker _recv_ticker; + std::shared_ptr _check_timer; + std::function _on_start; + std::function _on_shutdown; + #ifdef ENABLE_SCTP RTC::SctpAssociationImp::Ptr _sctp; #endif @@ -246,7 +287,7 @@ public: struct WrappedMediaTrack { MediaTrack::Ptr track; - explicit WrappedMediaTrack(MediaTrack::Ptr ptr): track(ptr) {} + explicit WrappedMediaTrack(MediaTrack::Ptr ptr): track(std::move(ptr)) {} virtual ~WrappedMediaTrack() {} virtual void inputRtp(const char *buf, size_t len, uint64_t stamp_ms, RtpHeader *rtp) = 0; }; @@ -278,22 +319,26 @@ public: uint64_t getDuration() const; bool canSendRtp() const; bool canRecvRtp() const; + bool canSendRtp(const RtcMedia& media) const; + bool canRecvRtp(const RtcMedia& media) const; void onSendRtp(const RtpPacket::Ptr &rtp, bool flush, bool rtx = false); void createRtpChannel(const std::string &rid, uint32_t ssrc, MediaTrack &track); - void removeTuple(RTC::TransportTuple* tuple); - void safeShutdown(const SockException &ex); + void safeShutdown(const toolkit::SockException &ex); void setPreferredTcp(bool flag) override; void setLocalIp(std::string local_ip) override; void setIceCandidate(std::vector cands) override; protected: - void OnIceServerSelectedTuple(const RTC::IceServer *iceServer, RTC::TransportTuple *tuple) override; - WebRtcTransportImp(const EventPoller::Ptr &poller); + + // // ice相关的回调 /// [AUTO-TRANSLATED:30abf693] + // // ice related callbacks /// + + WebRtcTransportImp(const toolkit::EventPoller::Ptr &poller); void OnDtlsTransportApplicationDataReceived(const RTC::DtlsTransport *dtlsTransport, const uint8_t *data, size_t len) override; void onStartWebRTC() override; - void onSendSockData(Buffer::Ptr buf, bool flush = true, RTC::TransportTuple *tuple = nullptr) override; + void onSendSockData(toolkit::Buffer::Ptr buf, bool flush = true, const IceTransport::Pair::Ptr& pair = nullptr) override; void onCheckSdp(SdpType type, RtcSession &sdp) override; void onRtcConfigure(RtcConfigure &configure) const override; @@ -303,7 +348,7 @@ protected: void onBeforeEncryptRtcp(const char *buf, int &len, void *ctx) override {}; void onCreate() override; void onDestory() override; - void onShutdown(const SockException &ex) override; + void onShutdown(const toolkit::SockException &ex) override; virtual void onRecvRtp(MediaTrack &track, const std::string &rid, RtpPacket::Ptr rtp) {} void updateTicker(); float getLossRate(TrackType type); @@ -330,13 +375,17 @@ private: Ptr _self; // 检测超时的定时器 [AUTO-TRANSLATED:a58e1388] // Timeout detection timer - Timer::Ptr _timer; + toolkit::Timer::Ptr _timer; // 刷新计时器 [AUTO-TRANSLATED:61eb11e5] // Refresh timer - Ticker _alive_ticker; + toolkit::Ticker _alive_ticker; // pli rtcp计时器 [AUTO-TRANSLATED:a1a5fd18] // pli rtcp timer - Ticker _pli_ticker; + toolkit::Ticker _pli_ticker; + + toolkit::Ticker _rtcp_sr_send_ticker; + toolkit::Ticker _rtcp_rr_send_ticker; + // twcc rtcp发送上下文对象 [AUTO-TRANSLATED:aef6476a] // twcc rtcp send context object TwccContext _twcc_ctx; @@ -374,20 +423,20 @@ private: class WebRtcArgs : public std::enable_shared_from_this { public: virtual ~WebRtcArgs() = default; - virtual variant operator[](const std::string &key) const = 0; + virtual toolkit::variant operator[](const std::string &key) const = 0; }; using onCreateWebRtc = std::function; class WebRtcPluginManager { public: - using Plugin = std::function; - using Listener = std::function; + using Plugin = std::function; + using Listener = std::function; static WebRtcPluginManager &Instance(); void registerPlugin(const std::string &type, Plugin cb); void setListener(Listener cb); - void negotiateSdp(Session &sender, const std::string &type, const WebRtcArgs &args, const onCreateWebRtc &cb); + void negotiateSdp(toolkit::SocketHelper& sender, const std::string &type, const WebRtcArgs &args, const onCreateWebRtc &cb); private: WebRtcPluginManager() = default; @@ -398,4 +447,8 @@ private: std::unordered_map _map_creator; }; -}// namespace mediakit \ No newline at end of file +void translateIPFromEnv(std::vector &v); + +}// namespace mediakit + +#endif // ZLMEDIAKIT_WEBRTC_TRANSPORT_H diff --git a/webrtc/readme.md b/webrtc/readme.md index 3ac82dee..1f2c2c07 100644 --- a/webrtc/readme.md +++ b/webrtc/readme.md @@ -1,13 +1,6 @@ # 致谢与声明 本文件夹下部分文件提取自[MediaSoup](https://github.com/versatica/mediasoup) ,分别为: -- ice相关功能: - - IceServer.cpp - - IceServer.hpp - - StunPacket.cpp - - StunPacket.hpp - - Utils.hpp - - dtls相关功能: - DtlsTransport.cpp - DtlsTransport.hpp diff --git a/webrtc/webrtcSignal.txt b/webrtc/webrtcSignal.txt new file mode 100644 index 00000000..f09bd176 --- /dev/null +++ b/webrtc/webrtcSignal.txt @@ -0,0 +1,132 @@ +webrtc websocket 信令 + +# register 注册 + +``` json +#client/peer --> server +{ + "class" : "request", + "method" : "register", + "transaction_id" : "HFaq5Jp2agKfDjizOT5jGpiPtOQ8yays" + "room_id" : "room_1", +} +#server --> client/peer +#success +#支持turn +{ + "class" : "accept", + "method" : "register", + "transaction_id" : "HFaq5Jp2agKfDjizOT5jGpiPtOQ8yays" + "room_id" : "room_1", + "ice_servers" : [ { + "pwd" : "ZLMediaKit", + "ufrag" : "ZLMediaKit", + "url" : "turn:10.9.120.61:3478?transport=udp" + } + ], +} + +#不支持turn +{ + "class" : "accept", + "method" : "register", + "transaction_id" : "HFaq5Jp2agKfDjizOT5jGpiPtOQ8yays" + "room_id" : "room_1", + "ice_servers" : [ { + "pwd" : "ZLMediaKit", + "ufrag" : "ZLMediaKit", + "url" : "stun:10.9.120.61:3478?transport=udp" + } + ], +} +``` + +#fail +``` json +{ + "class" : "reject", + "method" : "register", + "transaction_id" : "2DiOjTulA4Glp9Si7yHdQypibAn2LPaX" + "reason" : "alreadly register", + "room_id" : "room_1", +} +``` + +# unregister 注销 +# client --> server +```json +{ + "class" : "request", + "method" : "unregister", + "transaction_id" : "0Xbgr86OIacWvjJIc03EsxH3QIF1ou8m" + "room_id" : "room_1", +} +``` +# server --> client +# success +``` json +{ + "class" : "accept", + "method" : "unregister", + "room_id" : "room1", + "transaction_id" : "0Xbgr86OIacWvjJIc03EsxH3QIF1ou8m" +} + +``` + +# 呼叫 +# client --> server,server -透传-> peer +{ + "class" : "request", + "method" : "call", + "transaction_id" : "qUpN8C49bGiyOHk6WNanAFq2viSkk6HC", + "guest_id" : "guest1_EDuVWIxLUMlDDKDa", + "room_id" : "room_1", + "type" : "play", + "vhost" : "__defaultVhost__", + "app" : "live", + "stream" : "test", + "sdp" : "v=0\r\no=- 7040255305116218076 1 IN IP4 0.0.0.0\r\ns=-\r\nt=0 0\r\na=group:BUNDLE 0 1\r\na=extmap-allow-mixed\r\na=msid-semantic: WMS\r\nm=video 9 UDP/TLS/RTP/SAVPF 102 124 123 102 124 123 35 98 100 96\r\nc=IN IP4 0.0.0.0\r\na=ice-ufrag:rBIAAR9AH0A=_1\r\na=ice-pwd:V1WhKKOK9jrhmLPmZemhcO5h\r\na=ice-options:trickle\r\na=fingerprint:sha-256 B4:51:C0:D2:0E:60:70:C2:CD:40:3A:8E:33:EB:FC:67:F6:29:72:89:AC:23:48:90:A0:D7:C0:07:44:7B:F1:79\r\na=setup:active\r\na=mid:0\r\na=extmap:2 http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time\r\na=extmap:3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01\r\na=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id\r\na=extmap:6 urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id\r\na=extmap:7 http://www.webrtc.org/experiments/rtp-hdrext/video-timing\r\na=extmap:8 http://www.webrtc.org/experiments/rtp-hdrext/color-space\r\na=extmap:10 http://tools.ietf.org/html/draft-ietf-avtext-framemarking-07\r\na=extmap:11 http://www.webrtc.org/experiments/rtp-hdrext/video-content-type\r\na=extmap:12 http://www.webrtc.org/experiments/rtp-hdrext/playout-delay\r\na=extmap:14 urn:ietf:params:rtp-hdrext:toffset\r\na=recvonly\r\na=rtcp-mux\r\na=rtcp-rsize\r\na=rtpmap:102 H264/90000\r\na=rtcp-fb:102 ccm fir\r\na=rtcp-fb:102 goog-remb\r\na=rtcp-fb:102 nack\r\na=rtcp-fb:102 nack pli\r\na=rtcp-fb:102 transport-cc\r\na=fmtp:102 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f\r\na=rtpmap:124 H264/90000\r\na=rtcp-fb:124 ccm fir\r\na=rtcp-fb:124 goog-remb\r\na=rtcp-fb:124 nack\r\na=rtcp-fb:124 nack pli\r\na=rtcp-fb:124 transport-cc\r\na=fmtp:124 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=4d001f\r\na=rtpmap:123 H264/90000\r\na=rtcp-fb:123 ccm fir\r\na=rtcp-fb:123 goog-remb\r\na=rtcp-fb:123 nack\r\na=rtcp-fb:123 nack pli\r\na=rtcp-fb:123 transport-cc\r\na=fmtp:123 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=64001f\r\na=rtpmap:102 H264/90000\r\na=rtcp-fb:102 ccm fir\r\na=rtcp-fb:102 goog-remb\r\na=rtcp-fb:102 nack\r\na=rtcp-fb:102 nack pli\r\na=rtcp-fb:102 transport-cc\r\na=fmtp:102 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f\r\na=rtpmap:124 H264/90000\r\na=rtcp-fb:124 ccm fir\r\na=rtcp-fb:124 goog-remb\r\na=rtcp-fb:124 nack\r\na=rtcp-fb:124 nack pli\r\na=rtcp-fb:124 transport-cc\r\na=fmtp:124 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=4d001f\r\na=rtpmap:123 H264/90000\r\na=rtcp-fb:123 ccm fir\r\na=rtcp-fb:123 goog-remb\r\na=rtcp-fb:123 nack\r\na=rtcp-fb:123 nack pli\r\na=rtcp-fb:123 transport-cc\r\na=fmtp:123 level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=64001f\r\na=rtpmap:35 AV1/90000\r\na=rtcp-fb:35 ccm fir\r\na=rtcp-fb:35 goog-remb\r\na=rtcp-fb:35 nack\r\na=rtcp-fb:35 nack pli\r\na=rtcp-fb:35 transport-cc\r\na=rtpmap:98 VP9/90000\r\na=rtcp-fb:98 ccm fir\r\na=rtcp-fb:98 goog-remb\r\na=rtcp-fb:98 nack\r\na=rtcp-fb:98 nack pli\r\na=rtcp-fb:98 transport-cc\r\na=fmtp:98 profile-id==0\r\na=rtpmap:100 VP9/90000\r\na=rtcp-fb:100 ccm fir\r\na=rtcp-fb:100 goog-remb\r\na=rtcp-fb:100 nack\r\na=rtcp-fb:100 nack pli\r\na=rtcp-fb:100 transport-cc\r\na=fmtp:100 profile-id==2\r\na=rtpmap:96 VP8/90000\r\na=rtcp-fb:96 ccm fir\r\na=rtcp-fb:96 goog-remb\r\na=rtcp-fb:96 nack\r\na=rtcp-fb:96 nack pli\r\na=rtcp-fb:96 transport-cc\r\nm=audio 9 UDP/TLS/RTP/SAVPF 0 8 111 96\r\nc=IN IP4 0.0.0.0\r\na=ice-ufrag:rBIAAR9AH0A=_1\r\na=ice-pwd:V1WhKKOK9jrhmLPmZemhcO5h\r\na=ice-options:trickle\r\na=fingerprint:sha-256 B4:51:C0:D2:0E:60:70:C2:CD:40:3A:8E:33:EB:FC:67:F6:29:72:89:AC:23:48:90:A0:D7:C0:07:44:7B:F1:79\r\na=setup:active\r\na=mid:1\r\na=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level\r\na=extmap:2 http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time\r\na=extmap:3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01\r\na=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id\r\na=extmap:6 urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id\r\na=extmap:9 urn:ietf:params:rtp-hdrext:csrc-audio-level\r\na=recvonly\r\na=rtcp-mux\r\na=rtcp-rsize\r\na=rtpmap:0 PCMU/8000\r\na=rtcp-fb:0 goog-remb\r\na=rtcp-fb:0 transport-cc\r\na=rtpmap:8 PCMA/8000\r\na=rtcp-fb:8 goog-remb\r\na=rtcp-fb:8 transport-cc\r\na=rtpmap:111 opus/48000\r\na=rtcp-fb:111 goog-remb\r\na=rtcp-fb:111 transport-cc\r\na=rtpmap:96 mpeg4-generic/48000\r\na=rtcp-fb:96 goog-remb\r\na=rtcp-fb:96 transport-cc\r\n" +} + +#peer->server, server -透传->client +{ + "class" : "accept", + "method" : "call", + "transaction_id" : "qUpN8C49bGiyOHk6WNanAFq2viSkk6HC", + "guest_id" : "guest1_EDuVWIxLUMlDDKDa", + "room_id" : "room1", + "vhost" : "__defaultVhost__", + "app" : "live", + "stream" : "test", + "type" : "play", + "sdp" : "v=0\r\no=- 7040255305116218076 1 IN IP4 10.9.120.61\r\ns=-\r\nt=0 0\r\na=group:BUNDLE 0 1\r\na=extmap-allow-mixed\r\na=msid-semantic: WMS\r\na=ice-lite\r\nm=video 28000 UDP/TLS/RTP/SAVPF 102\r\nc=IN IP4 10.9.120.61\r\na=rtcp:28000 IN IP4 10.9.120.61\r\na=ice-ufrag:rBIAAW1gbWA=_1\r\na=ice-pwd:NmPNJgMbz9z2kH3g97yZFbCn\r\na=ice-options:trickle\r\na=fingerprint:sha-256 B4:51:C0:D2:0E:60:70:C2:CD:40:3A:8E:33:EB:FC:67:F6:29:72:89:AC:23:48:90:A0:D7:C0:07:44:7B:F1:79\r\na=setup:passive\r\na=mid:0\r\na=ice-lite\r\na=extmap:2 http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time\r\na=extmap:3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01\r\na=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id\r\na=extmap:6 urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id\r\na=extmap:7 http://www.webrtc.org/experiments/rtp-hdrext/video-timing\r\na=extmap:8 http://www.webrtc.org/experiments/rtp-hdrext/color-space\r\na=extmap:10 http://tools.ietf.org/html/draft-ietf-avtext-framemarking-07\r\na=extmap:11 http://www.webrtc.org/experiments/rtp-hdrext/video-content-type\r\na=extmap:12 http://www.webrtc.org/experiments/rtp-hdrext/playout-delay\r\na=extmap:14 urn:ietf:params:rtp-hdrext:toffset\r\na=sendonly\r\na=rtcp-mux\r\na=rtpmap:102 H264/90000\r\na=rtcp-fb:102 ccm fir\r\na=rtcp-fb:102 goog-remb\r\na=rtcp-fb:102 nack\r\na=rtcp-fb:102 nack pli\r\na=rtcp-fb:102 transport-cc\r\na=fmtp:102 level-asymmetry-allowed=1;packetization-mode=0;profile-level-id=42001f\r\na=msid:zlmediakit-mslabel zlmediakit-label-0\r\na=ssrc:1 cname:zlmediakit-rtp\r\na=ssrc:1 msid:zlmediakit-mslabel zlmediakit-label-0\r\na=ssrc:1 mslabel:zlmediakit-mslabel\r\na=ssrc:1 label:zlmediakit-label-0\r\nm=audio 28000 UDP/TLS/RTP/SAVPF 0\r\nc=IN IP4 10.9.120.61\r\na=rtcp:28000 IN IP4 10.9.120.61\r\na=ice-ufrag:rBIAAW1gbWA=_1\r\na=ice-pwd:NmPNJgMbz9z2kH3g97yZFbCn\r\na=ice-options:trickle\r\na=fingerprint:sha-256 B4:51:C0:D2:0E:60:70:C2:CD:40:3A:8E:33:EB:FC:67:F6:29:72:89:AC:23:48:90:A0:D7:C0:07:44:7B:F1:79\r\na=setup:passive\r\na=mid:1\r\na=ice-lite\r\na=extmap:1 urn:ietf:params:rtp-hdrext:ssrc-audio-level\r\na=extmap:2 http://www.webrtc.org/experiments/rtp-hdrext/abs-send-time\r\na=extmap:3 http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01\r\na=extmap:5 urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id\r\na=extmap:6 urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id\r\na=extmap:9 urn:ietf:params:rtp-hdrext:csrc-audio-level\r\na=inactive\r\na=rtcp-mux\r\na=rtpmap:0 PCMU/8000/1\r\na=rtcp-fb:0 goog-remb\r\na=rtcp-fb:0 transport-cc\r\na=msid:zlmediakit-mslabel zlmediakit-label-1\r\na=ssrc:2 cname:zlmediakit-rtp\r\na=ssrc:2 msid:zlmediakit-mslabel zlmediakit-label-1\r\na=ssrc:2 mslabel:zlmediakit-mslabel\r\na=ssrc:2 label:zlmediakit-label-1\r\n" +} + +# candidate +```peer--> server -透传-> peer +{ + "class" : "indication", + "method" : "candidate", + "transaction_id" : "7oEa2vcYvps7aZ1g9UGIPoFf5PrTl2N9", + "guest_id" : "guest1_n9WyhNMR42EvkOvE", + "room_id" : "room_1", + "candidate" : "7e0de214 1 udp 2113955071 192.168.1.1 46411 typ host", + "ufrag" : "rBIAAW1gbWA=_1" + "pwd" : "gDNJZM0uVLlnNnthaE41KXOp", +} + +# bye + +``` +{ + "class" : "indication", + "method" : "bye", + "transaction_id" : "86RdplPz21Ow9DwR1gvXjsAdmh30TAf3" + "guest_id" : "guest1_n9WyhNMR42EvkOvE", + "reason" : "peer unregister", + "room_id" : "room_1", +} +``` + + diff --git a/www/logo.ico b/www/logo.ico new file mode 100644 index 00000000..132a1ec5 Binary files /dev/null and b/www/logo.ico differ diff --git a/www/webrtc/index.html b/www/webrtc/index.html index 704f4790..2e8bedc6 100644 --- a/www/webrtc/index.html +++ b/www/webrtc/index.html @@ -66,8 +66,9 @@

- + echo + talk push play