diff --git a/.depot/workflows/android-instrumented.yml b/.depot/workflows/android-instrumented.yml index 7544629b8..d54acaa43 100644 --- a/.depot/workflows/android-instrumented.yml +++ b/.depot/workflows/android-instrumented.yml @@ -61,9 +61,12 @@ jobs: # rebuilt when those sources change (mirrors the Namespace libcore cache strategy). - name: libcore status run: | - find buildScript libcore/*.sh buildScript/lib/core/get_source_env.sh \ - | sort | xargs cat | sha1sum | awk '{print $1}' > libcore_status - git ls-files libcore | sort | xargs cat | sha1sum | awk '{print $1}' >> libcore_status + # git ls-files -s emits for every tracked + # entry (content- and path-aware) without touching the filesystem, so it is + # deterministic and safe even for dangling symlinks (e.g. buildScript/nkmr). + git ls-files -s -- buildScript libcore \ + | sha256sum \ + | awk '{print $1}' > libcore_status cat libcore_status - name: libcore cache diff --git a/.depot/workflows/build-apk.yml b/.depot/workflows/build-apk.yml new file mode 100644 index 000000000..f7ffc47fa --- /dev/null +++ b/.depot/workflows/build-apk.yml @@ -0,0 +1,148 @@ +name: Build APK +on: + push: + branches: + - '**' + workflow_dispatch: +permissions: + contents: read +env: + GO_VERSION: '1.26.4' + NDK_VERSION: '25.0.8775105' + MIERU_VERSION: v3.34.0 + HYSTERIA_VERSION: v2.9.2 + MDVPN_REF: android-vpnservice-protect-hook + MDVPN_COMMIT: d481d72d4b86783a87d536c214d2c68cc4e9320e + NAIVE_VERSION: v149.0.7827.114-1 +jobs: + build-apk: + name: Build OSS Debug APK + runs-on: depot-ubuntu-24.04-8 + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup Java + uses: actions/setup-java@v5 + with: + distribution: 'temurin' + java-version: '17' + + - name: Setup Android SDK + uses: android-actions/setup-android@v3 + + - name: Install SDK platform + NDK + run: sdkmanager --install "platforms;android-35" "build-tools;35.0.0" "ndk;${NDK_VERSION}" + + - name: local.properties + run: | + echo "sdk.dir=${ANDROID_HOME}" > local.properties + echo "ndk.dir=${ANDROID_HOME}/ndk/${NDK_VERSION}" >> local.properties + + - name: libcore status + run: | + # git ls-files -s emits for every tracked + # entry (content- and path-aware) without touching the filesystem, so it is + # deterministic and safe even for dangling symlinks (e.g. buildScript/nkmr). + git ls-files -s -- buildScript libcore \ + | sha256sum \ + | awk '{print $1}' > libcore_status + + - name: libcore cache + id: libcore-cache + uses: actions/cache@v4 + with: + path: app/libs/libcore.aar + key: depot-libcore-${{ env.GO_VERSION }}-${{ env.NDK_VERSION }}-${{ hashFiles('libcore_status') }} + + - name: sidecars status + run: | + git ls-files -s -- \ + buildScript/lib/mieru.sh \ + buildScript/lib/hysteria2.sh \ + buildScript/lib/masterdnsvpn.sh \ + buildScript/lib/naive.sh \ + buildScript/init/env.sh \ + buildScript/init/env_ndk.sh \ + | sha256sum \ + | awk '{print $1}' > sidecars_status + + - name: sidecars cache + id: sidecars-cache + uses: actions/cache@v4 + with: + path: app/executableSo + key: depot-sidecars-${{ env.GO_VERSION }}-${{ env.NDK_VERSION }}-${{ env.MIERU_VERSION }}-${{ env.HYSTERIA_VERSION }}-${{ env.MDVPN_COMMIT }}-${{ env.NAIVE_VERSION }}-${{ hashFiles('sidecars_status') }} + + - name: Install Go + if: steps.libcore-cache.outputs.cache-hit != 'true' || steps.sidecars-cache.outputs.cache-hit != 'true' + uses: actions/setup-go@v6 + with: + go-version: ${{ env.GO_VERSION }} + + - name: Build libcore (Go + gomobile) + if: steps.libcore-cache.outputs.cache-hit != 'true' + run: ./run lib core + + - name: Build sidecars + if: steps.sidecars-cache.outputs.cache-hit != 'true' + run: | + ./run lib mieru + ./run lib hysteria2 + ./run lib masterdnsvpn + ./run lib naive + + - name: Verify sidecar artifacts + run: | + for abi in arm64-v8a armeabi-v7a x86 x86_64; do + for so in libmieru.so libhysteria2.so libmasterdnsvpn.so libnaive.so; do + if [ ! -f "app/executableSo/$abi/$so" ]; then + echo "Error: missing app/executableSo/$abi/$so" >&2 + exit 1 + fi + done + done + + - name: Gradle cache + uses: actions/cache@v4 + with: + path: ~/.gradle + key: depot-gradle-oss-${{ hashFiles('**/*.gradle.kts') }} + + - name: Build APK + env: + BUILD_PLUGIN: none + KEYSTORE_B64: ${{ secrets.KEYSTORE_B64 }} + KEYSTORE_PASS: ${{ secrets.KEYSTORE_PASS }} + ALIAS_NAME: ${{ secrets.ALIAS_NAME }} + ALIAS_PASS: ${{ secrets.ALIAS_PASS }} + run: | + set -euo pipefail + if [ -z "${KEYSTORE_B64}" ]; then + echo "Error: signing keystore is not configured" >&2 + exit 1 + fi + if [ -z "${KEYSTORE_PASS}" ] || [ -z "${ALIAS_NAME}" ] || [ -z "${ALIAS_PASS}" ]; then + echo "Error: signing configuration is incomplete" >&2 + exit 1 + fi + ./run init action gradle + echo "${KEYSTORE_B64}" | base64 -d > release.keystore + ./gradlew app:assembleOssDebug + python3 - <<'PY' > apk_file + from pathlib import Path + matches = sorted(Path('app/build/outputs/renamed_apks').rglob('*arm64-v8a*.apk')) + if not matches: + raise SystemExit('Error: no arm64-v8a APK found under app/build/outputs/renamed_apks') + print(matches[0]) + PY + APK_FILE=$(cat apk_file) + echo "APK_FILE=$APK_FILE" >> "$GITHUB_ENV" + echo "Built APK: $APK_FILE" + + - name: Upload APK + uses: actions/upload-artifact@v4 + with: + name: NekoBox-debug-arm64-v8a-apk + path: ${{ env.APK_FILE }} + if-no-files-found: error diff --git a/.depot/workflows/lint.yml b/.depot/workflows/lint.yml index 6b4b7fce2..daa7917a8 100644 --- a/.depot/workflows/lint.yml +++ b/.depot/workflows/lint.yml @@ -47,10 +47,12 @@ jobs: - name: libcore status run: | - { find buildScript libcore -type f -print0; \ - printf '%s\0' buildScript/lib/core/get_source_env.sh; } \ - | sort -z | xargs -0 cat | sha1sum | awk '{print $1}' > libcore_status - git ls-files libcore | sort | xargs cat | sha1sum | awk '{print $1}' >> libcore_status + # git ls-files -s emits for every tracked + # entry (content- and path-aware) without touching the filesystem, so it is + # deterministic and safe even for dangling symlinks (e.g. buildScript/nkmr). + git ls-files -s -- buildScript libcore \ + | sha256sum \ + | awk '{print $1}' > libcore_status - name: libcore cache id: libcore-cache uses: actions/cache@v4 diff --git a/.depot/workflows/unit-tests.yml b/.depot/workflows/unit-tests.yml index acaf726e1..45f10e467 100644 --- a/.depot/workflows/unit-tests.yml +++ b/.depot/workflows/unit-tests.yml @@ -49,10 +49,12 @@ jobs: - name: libcore status run: | - { find buildScript libcore -type f -print0; \ - printf '%s\0' buildScript/lib/core/get_source_env.sh; } \ - | sort -z | xargs -0 cat | sha1sum | awk '{print $1}' > libcore_status - git ls-files libcore | sort | xargs cat | sha1sum | awk '{print $1}' >> libcore_status + # git ls-files -s emits for every tracked + # entry (content- and path-aware) without touching the filesystem, so it is + # deterministic and safe even for dangling symlinks (e.g. buildScript/nkmr). + git ls-files -s -- buildScript libcore \ + | sha256sum \ + | awk '{print $1}' > libcore_status - name: libcore cache id: libcore-cache uses: actions/cache@v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a2f0fa9c..f64a3a61f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,11 +1,6 @@ name: CI on: - push: - tags-ignore: - - 'v*' - branches: - - '*' - pull_request: + workflow_dispatch: env: MIERU_VERSION: v3.34.0 HYSTERIA_VERSION: v2.9.2 diff --git a/app/lint-baseline.xml b/app/lint-baseline.xml index d16d97cbd..943ad3f13 100644 --- a/app/lint-baseline.xml +++ b/app/lint-baseline.xml @@ -1,27 +1,7 @@ - - - - - - = Build.VERSION_CODES.TIRAMISU && + ContextCompat.checkSelfPermission(serviceContext, POST_NOTIFICATIONS) != + PackageManager.PERMISSION_GRANTED + ) { + return@useBuilder + } + try { + notificationManager.notify(notificationId, it.build()) + } catch (e: SecurityException) { + Logs.w("service notification update skipped", e) + } } fun destroy() { diff --git a/app/src/main/java/io/nekohasekai/sagernet/bg/SubscriptionUpdater.kt b/app/src/main/java/io/nekohasekai/sagernet/bg/SubscriptionUpdater.kt index 33e30c436..3e33d9ea9 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/bg/SubscriptionUpdater.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/bg/SubscriptionUpdater.kt @@ -1,8 +1,12 @@ package io.nekohasekai.sagernet.bg +import android.Manifest.permission.POST_NOTIFICATIONS import android.content.Context +import android.content.pm.PackageManager +import android.os.Build import androidx.core.app.NotificationCompat import androidx.core.app.NotificationManagerCompat +import androidx.core.content.ContextCompat import androidx.work.CoroutineWorker import androidx.work.ExistingPeriodicWorkPolicy.UPDATE import androidx.work.PeriodicWorkRequest @@ -16,6 +20,34 @@ import io.nekohasekai.sagernet.ktx.Logs import io.nekohasekai.sagernet.ktx.app import java.util.concurrent.TimeUnit +internal data class SubscriptionWorkSchedule( + val intervalMinutes: Long, + val initialDelaySeconds: Long, +) + +internal data class SubscriptionScheduleInput( + val lastUpdated: Int, + val autoUpdateDelay: Int, +) + +internal fun computeSubscriptionWorkSchedule( + subscriptions: List, + nowSeconds: Long = System.currentTimeMillis() / 1000L, +): SubscriptionWorkSchedule? { + if (subscriptions.isEmpty()) return null + + val intervalMinutes = subscriptions + .minOf { it.autoUpdateDelay.toLong() } + .coerceAtLeast(15L) + val initialDelaySeconds = subscriptions.minOf { subscription -> + val dueAt = subscription.lastUpdated.toLong() + + subscription.autoUpdateDelay.toLong().coerceAtLeast(15L) * 60L + dueAt - nowSeconds + }.coerceAtLeast(0L) + + return SubscriptionWorkSchedule(intervalMinutes, initialDelaySeconds) +} + object SubscriptionUpdater { private const val WORK_NAME = "SubscriptionUpdater" @@ -28,22 +60,24 @@ object SubscriptionUpdater { .filter { (_, sub) -> sub.autoUpdate } if (subscriptions.isEmpty()) return - // PeriodicWorkRequest.MIN_PERIODIC_INTERVAL_MILLIS - var minDelay = - subscriptions.minByOrNull { (_, sub) -> sub.autoUpdateDelay }!!.second.autoUpdateDelay.toLong() - val now = System.currentTimeMillis() / 1000L - var minInitDelay = - subscriptions.minOf { (_, sub) -> now - sub.lastUpdated - (minDelay * 60) } - if (minDelay < 15) minDelay = 15 - if (minInitDelay > 60) minInitDelay = 60 + val schedule = computeSubscriptionWorkSchedule( + subscriptions.map { (_, sub) -> + SubscriptionScheduleInput( + lastUpdated = sub.lastUpdated ?: 0, + autoUpdateDelay = sub.autoUpdateDelay ?: 1440, + ) + }, + ) ?: return // main process RemoteWorkManager.getInstance(app).enqueueUniquePeriodicWork( WORK_NAME, UPDATE, - PeriodicWorkRequest.Builder(UpdateTask::class.java, minDelay, TimeUnit.MINUTES) + PeriodicWorkRequest.Builder(UpdateTask::class.java, schedule.intervalMinutes, TimeUnit.MINUTES) .apply { - if (minInitDelay > 0) setInitialDelay(minInitDelay, TimeUnit.SECONDS) + if (schedule.initialDelaySeconds > 0) { + setInitialDelay(schedule.initialDelaySeconds, TimeUnit.SECONDS) + } } .build(), ) @@ -54,9 +88,9 @@ object SubscriptionUpdater { params: WorkerParameters, ) : CoroutineWorker(appContext, params) { - val nm = NotificationManagerCompat.from(applicationContext) + private val nm = NotificationManagerCompat.from(applicationContext) - val notification = NotificationCompat.Builder(applicationContext, "service-subscription") + private val notification = NotificationCompat.Builder(applicationContext, "service-subscription") .setWhen(0) .setTicker(applicationContext.getString(R.string.forward_success)) .setContentTitle(applicationContext.getString(R.string.subscription_update)) @@ -73,9 +107,15 @@ object SubscriptionUpdater { subscriptions = subscriptions.filter { (_, sub) -> !sub.updateWhenConnectedOnly } } + var attempted = false + var failed = false if (subscriptions.isNotEmpty()) { + val nowSeconds = System.currentTimeMillis() / 1000L for ((profile, subscription) in subscriptions) { - if (((System.currentTimeMillis() / 1000).toInt() - subscription.lastUpdated) < subscription.autoUpdateDelay * 60) { + val lastUpdated = (subscription.lastUpdated ?: 0).toLong() + val delaySeconds = + (subscription.autoUpdateDelay ?: 1440).toLong().coerceAtLeast(15L) * 60L + if (nowSeconds - lastUpdated < delaySeconds) { Logs.d("work: not updating " + profile.displayName()) continue } @@ -87,15 +127,38 @@ object SubscriptionUpdater { profile.displayName(), ), ) - nm.notify(2, notification.build()) + notifyProgress() - GroupUpdater.executeUpdate(profile, false) + attempted = true + if (!GroupUpdater.executeUpdate(profile, false)) { + failed = true + } } } - nm.cancel(2) + try { + nm.cancel(2) + } catch (e: SecurityException) { + Logs.w("subscription notification cancel skipped", e) + } - return Result.success() + return if (attempted && failed) Result.retry() else Result.success() + } + + private fun notifyProgress() { + if (!nm.areNotificationsEnabled()) return + if ( + Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU && + ContextCompat.checkSelfPermission(applicationContext, POST_NOTIFICATIONS) != + PackageManager.PERMISSION_GRANTED + ) { + return + } + try { + nm.notify(2, notification.build()) + } catch (e: SecurityException) { + Logs.w("subscription notification update skipped", e) + } } } } diff --git a/app/src/main/java/io/nekohasekai/sagernet/group/RawUpdater.kt b/app/src/main/java/io/nekohasekai/sagernet/group/RawUpdater.kt index 7ab618d84..2bc2861ff 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/group/RawUpdater.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/group/RawUpdater.kt @@ -75,7 +75,7 @@ object RawUpdater : GroupUpdater() { setURL(subscription.link) setUserAgent(subscription.customUserAgent.takeIf { it.isNotBlank() } ?: USER_AGENT) }.execute() - proxies = parseRaw(Util.getStringBox(response.contentString)) + proxies = parseRaw(Util.getStringBox(response.getContentStringLimited(10L * 1024 * 1024))) ?: error(app.getString(R.string.no_proxies_found)) subscription.subscriptionUserinfo = @@ -234,25 +234,33 @@ object RawUpdater : GroupUpdater() { userOrder++ } - SagerDatabase.proxyDao.insert(toInsert) - Logs.d("Inserted profiles: ${toInsert.size}") - - SagerDatabase.proxyDao.updateProxy(toUpdate).also { - Logs.d("Updated profiles: $it") - } - - SagerDatabase.proxyDao.deleteProxy(toDelete).also { - Logs.d("Deleted profiles: $it") - } + var updatedCount = 0 + var deletedCount = 0 + SagerDatabase.instance.runInTransaction { + if (toInsert.isNotEmpty()) { + SagerDatabase.proxyDao.insert(toInsert) + } + if (toUpdate.isNotEmpty()) { + updatedCount = SagerDatabase.proxyDao.updateProxy(toUpdate) + } + if (toDelete.isNotEmpty()) { + deletedCount = SagerDatabase.proxyDao.deleteProxy(toDelete) + } - val existCount = SagerDatabase.proxyDao.countByGroup(proxyGroup.id).toInt() + val existCount = SagerDatabase.proxyDao.countByGroup(proxyGroup.id).toInt() + if (existCount != proxies.size) { + val message = "Exist profiles: $existCount, new profiles: ${proxies.size}" + Logs.e(message) + error(message) + } - if (existCount != proxies.size) { - Logs.e("Exist profiles: $existCount, new profiles: ${proxies.size}") + subscription.lastUpdated = (System.currentTimeMillis() / 1000).toInt() + SagerDatabase.groupDao.updateGroup(proxyGroup) } - subscription.lastUpdated = (System.currentTimeMillis() / 1000).toInt() - SagerDatabase.groupDao.updateGroup(proxyGroup) + Logs.d("Inserted profiles: ${toInsert.size}") + Logs.d("Updated profiles: $updatedCount") + Logs.d("Deleted profiles: $deletedCount") finishUpdate(proxyGroup) userInterface?.onUpdateSuccess( diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/AboutFragment.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/AboutFragment.kt index 8c88d77d6..5fcf86127 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/ui/AboutFragment.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/AboutFragment.kt @@ -191,7 +191,7 @@ class AboutFragment : ToolbarFragment(R.layout.layout_about) { setURL("https://api.github.com/repos/hawkff/NekoBoxForAndroid/releases/latest") } }.execute() - val release = JSONObject(Util.getStringBox(response.contentString)) + val release = JSONObject(Util.getStringBox(response.getContentStringLimited(10L * 1024 * 1024))) val releaseName = release.getString("name") val releaseUrl = release.getString("html_url") var haveUpdate = releaseName.isNotBlank() diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/AssetsActivity.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/AssetsActivity.kt index 1a287378f..702f1b366 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/ui/AssetsActivity.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/AssetsActivity.kt @@ -2,6 +2,7 @@ package io.nekohasekai.sagernet.ui import android.os.Bundle import android.provider.OpenableColumns +import android.system.Os import android.text.format.DateFormat import android.view.Menu import android.view.MenuItem @@ -26,6 +27,9 @@ import java.io.FileWriter import java.util.* import java.util.concurrent.atomic.AtomicInteger +private const val MAX_HTTP_JSON_BYTES = 10L * 1024 * 1024 +private const val MAX_RULE_ASSET_BYTES = 256L * 1024 * 1024 + /** * Reduce an untrusted file name (e.g. a content provider's DISPLAY_NAME) to a safe basename * for an importable asset, or null if invalid. Strips any directory component and rejects @@ -298,6 +302,14 @@ class AssetsActivity : ThemedActivity() { } } + private fun replaceAssetFile(tempFile: File, targetFile: File) { + try { + Os.rename(tempFile.absolutePath, targetFile.absolutePath) + } catch (e: Exception) { + error("Unable to save the route asset: ${e.readableMessage}") + } + } + private val rulesProviders = listOf( RuleAssetsProvider( "SagerNet/sing-geoip", @@ -337,7 +349,7 @@ class AssetsActivity : ThemedActivity() { setURL("https://api.github.com/repos/$repo/releases/latest") }.execute() - val release = JSONObject(Util.getStringBox(response.contentString)) + val release = JSONObject(Util.getStringBox(response.getContentStringLimited(MAX_HTTP_JSON_BYTES))) val tagName = release.optString("tag_name") if (tagName == localVersion) { @@ -348,9 +360,13 @@ class AssetsActivity : ThemedActivity() { } val releaseAssets = release.getJSONArray("assets").filterIsInstance() - val assetToDownload = releaseAssets.find { it.getStr("name") == fileName } - ?: error("File $fileName not found in release ${release["url"]}") + val assetToDownload = releaseAssets.find { + val assetName = it.getStr("name") + assetName == fileName || assetName == "$fileName.xz" + } ?: error("File $fileName not found in release ${release["url"]}") + val downloadName = assetToDownload.getStr("name") ?: error("Release asset is missing a name") val browserDownloadUrl = assetToDownload.getStr("browser_download_url") + ?: error("Release asset $downloadName is missing a download URL") response = client.newRequest().apply { setURL(browserDownloadUrl) @@ -359,16 +375,27 @@ class AssetsActivity : ThemedActivity() { val cacheFile = File(file.parentFile, fileName + ".tmp") cacheFile.parentFile?.mkdirs() - response.writeTo(cacheFile.canonicalPath) + try { + response.writeToLimited(cacheFile.canonicalPath, MAX_RULE_ASSET_BYTES) - if (fileName.endsWith(".xz")) { - Libcore.unxz(cacheFile.absolutePath, file.absolutePath) - cacheFile.delete() - } else { - cacheFile.renameTo(file) - } + if (downloadName.endsWith(".xz")) { + val unpackedFile = File(file.parentFile, file.nameWithoutExtension + ".unxz.tmp") + try { + // Libcore.unxz enforces the same 256 MB cap (defaultUnxzFileLimit) + // and fails before writing if exceeded, so no extra size check here. + Libcore.unxz(cacheFile.absolutePath, unpackedFile.absolutePath) + replaceAssetFile(unpackedFile, file) + } finally { + if (unpackedFile.exists()) unpackedFile.delete() + } + } else { + replaceAssetFile(cacheFile, file) + } - versionFile.writeText(tagName) + versionFile.writeText(tagName) + } finally { + if (cacheFile.exists()) cacheFile.delete() + } adapter.reloadAssets() @@ -404,11 +431,15 @@ class AssetsActivity : ThemedActivity() { }.execute() val cacheFile = File(file.parentFile, fileName + ".tmp") cacheFile.parentFile?.mkdirs() - response.writeTo(cacheFile.canonicalPath) - cacheFile.renameTo(file) - - val currentDate = java.text.SimpleDateFormat("yyyyMMdd").format(java.util.Date()) - versionFile.writeText(currentDate) + try { + response.writeToLimited(cacheFile.canonicalPath, MAX_RULE_ASSET_BYTES) + replaceAssetFile(cacheFile, file) + + val currentDate = java.text.SimpleDateFormat("yyyyMMdd", Locale.US).format(Date()) + versionFile.writeText(currentDate) + } finally { + if (cacheFile.exists()) cacheFile.delete() + } adapter.reloadAssets() onMainDispatcher { diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFormatV2.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFormatV2.kt new file mode 100644 index 000000000..1b38ca342 --- /dev/null +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFormatV2.kt @@ -0,0 +1,167 @@ +package io.nekohasekai.sagernet.ui + +import io.nekohasekai.sagernet.database.ProxyEntity +import io.nekohasekai.sagernet.database.ProxyGroup +import io.nekohasekai.sagernet.database.RuleEntity +import io.nekohasekai.sagernet.database.preference.KeyValuePair +import io.nekohasekai.sagernet.fmt.KryoConverters +import org.json.JSONArray +import org.json.JSONObject +import java.util.Base64 + +internal object BackupFormatV2 { + const val VERSION = 2 + + fun encodeProfiles(profiles: List): JSONArray = JSONArray().apply { + profiles.forEach { put(encodeProfile(it)) } + } + + fun decodeProfiles(array: JSONArray): List = array.mapObjects(::decodeProfile) + + fun encodeGroups(groups: List): JSONArray = JSONArray().apply { + groups.forEach { put(encodeGroup(it)) } + } + + fun decodeGroups(array: JSONArray): List = array.mapObjects(::decodeGroup) + + fun encodeRules(rules: List): JSONArray = JSONArray().apply { + rules.forEach { put(encodeRule(it)) } + } + + fun decodeRules(array: JSONArray): List = array.mapObjects(::decodeRule) + + fun encodeSettings(settings: List): JSONArray = JSONArray().apply { + settings.forEach { put(encodeSetting(it)) } + } + + fun decodeSettings(array: JSONArray): List = array.mapObjects(::decodeSetting) + + fun encodeProfile(profile: ProxyEntity): JSONObject = JSONObject().apply { + put("id", profile.id) + put("groupId", profile.groupId) + put("type", profile.type) + put("userOrder", profile.userOrder) + put("tx", profile.tx) + put("rx", profile.rx) + put("status", profile.status) + put("ping", profile.ping) + put("uuid", profile.uuid) + putNullable("error", profile.error) + put("dirty", profile.dirty) + put("bean", encodeBytes(KryoConverters.serialize(profile.requireBean()))) + } + + fun decodeProfile(json: JSONObject): ProxyEntity = ProxyEntity( + id = json.getLong("id"), + groupId = json.getLong("groupId"), + type = json.getInt("type"), + userOrder = json.getLong("userOrder"), + tx = json.getLong("tx"), + rx = json.getLong("rx"), + status = json.getInt("status"), + ping = json.getInt("ping"), + uuid = json.getString("uuid"), + error = json.optNullableString("error"), + ).apply { + dirty = json.optBoolean("dirty", false) + putByteArray(decodeBytes(json.getString("bean"))) + } + + fun encodeGroup(group: ProxyGroup): JSONObject = JSONObject().apply { + put("id", group.id) + put("userOrder", group.userOrder) + put("ungrouped", group.ungrouped) + putNullable("name", group.name) + put("type", group.type) + put("order", group.order) + put("isSelector", group.isSelector) + put("frontProxy", group.frontProxy) + put("landingProxy", group.landingProxy) + putNullable("subscription", group.subscription?.let { encodeBytes(KryoConverters.serialize(it)) }) + } + + fun decodeGroup(json: JSONObject): ProxyGroup = ProxyGroup( + id = json.getLong("id"), + userOrder = json.getLong("userOrder"), + ungrouped = json.getBoolean("ungrouped"), + name = json.optNullableString("name"), + type = json.getInt("type"), + subscription = json.optNullableString("subscription")?.let { + KryoConverters.subscriptionDeserialize(decodeBytes(it)) + }, + order = json.getInt("order"), + isSelector = json.optBoolean("isSelector", false), + frontProxy = json.optLong("frontProxy", -1L), + landingProxy = json.optLong("landingProxy", -1L), + ) + + fun encodeRule(rule: RuleEntity): JSONObject = JSONObject().apply { + put("id", rule.id) + put("name", rule.name) + put("config", rule.config) + put("userOrder", rule.userOrder) + put("enabled", rule.enabled) + put("domains", rule.domains) + put("ip", rule.ip) + put("port", rule.port) + put("sourcePort", rule.sourcePort) + put("network", rule.network) + put("source", rule.source) + put("protocol", rule.protocol) + put("ruleset", rule.ruleset) + put("outbound", rule.outbound) + put("packages", JSONArray().apply { rule.packages.sorted().forEach { put(it) } }) + } + + fun decodeRule(json: JSONObject): RuleEntity = RuleEntity( + id = json.getLong("id"), + name = json.getString("name"), + config = json.getString("config"), + userOrder = json.getLong("userOrder"), + enabled = json.getBoolean("enabled"), + domains = json.getString("domains"), + ip = json.getString("ip"), + port = json.getString("port"), + sourcePort = json.getString("sourcePort"), + network = json.getString("network"), + source = json.getString("source"), + protocol = json.getString("protocol"), + ruleset = json.getString("ruleset"), + outbound = json.getLong("outbound"), + packages = json.getJSONArray("packages").mapStrings().toSet(), + ) + + fun encodeSetting(setting: KeyValuePair): JSONObject = JSONObject().apply { + put("key", setting.key) + put("valueType", setting.valueType) + put("value", encodeBytes(setting.value)) + } + + fun decodeSetting(json: JSONObject): KeyValuePair = KeyValuePair().apply { + key = json.getString("key") + valueType = json.getInt("valueType") + value = decodeBytes(json.getString("value")) + } + + private fun encodeBytes(bytes: ByteArray): String { + return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes) + } + + private fun decodeBytes(text: String): ByteArray = Base64.getUrlDecoder().decode(text) + + private fun JSONObject.putNullable(name: String, value: Any?) { + put(name, value ?: JSONObject.NULL) + } + + private fun JSONObject.optNullableString(name: String): String? { + return if (!has(name) || isNull(name)) null else getString(name) + } + + private fun JSONArray.mapObjects(transform: (JSONObject) -> T): List { + return (0 until length()).map { index -> transform(getJSONObject(index)) } + } + + private fun JSONArray.mapStrings(): List { + return (0 until length()).map { index -> getString(index) } + } +} diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFragment.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFragment.kt index 9b8e0bbfa..9a88cc5f8 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFragment.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/BackupFragment.kt @@ -43,6 +43,11 @@ import java.util.zip.ZipEntry import java.util.zip.ZipInputStream import java.util.zip.ZipOutputStream +internal fun backupFileName(now: Date = Date()): String { + val timestamp = java.text.SimpleDateFormat("yyyyMMdd_HHmmss", Locale.US).format(now) + return "nekobox_backup_$timestamp.json" +} + class BackupFragment : NamedFragment(R.layout.layout_backup) { private lateinit var binding: LayoutBackupBinding @@ -111,7 +116,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { onMainDispatcher { startFilesForResult( exportSettings, - "nekobox_backup_${Date().toLocaleString()}.json", + backupFileName(), ) } } @@ -127,7 +132,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { val shareDir = File(app.cacheDir, "share").apply { mkdirs() } val cacheFile = File( shareDir, - "nekobox_backup_${Date().toLocaleString()}.json", + backupFileName(), ) cacheFile.writeBytes(backupData) onMainDispatcher { @@ -221,8 +226,8 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { .addPathSegment(fileName) .build() - Logs.d("WebDAV backup - Directory URL: $dirUrl") - Logs.d("WebDAV backup - File URL: $fileUrl") + Logs.d("WebDAV backup - Directory URL: ${redactedWebDavUrlForLog(dirUrl)}") + Logs.d("WebDAV backup - File URL: ${redactedWebDavUrlForLog(fileUrl)}") // first check whether the directory exists val propfindRequest = Request.Builder() @@ -247,8 +252,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { 401 -> throw Exception("Authentication failed") else -> { if (!response.isSuccessful) { - val errorBody = response.body?.string() - Logs.e("WebDAV backup - PROPFIND error: $errorBody") + Logs.e("WebDAV backup - PROPFIND failed: ${response.code} ${response.message}") throw Exception("Failed to check directory (${response.code}): ${response.message}") } } @@ -272,8 +276,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { client.newCall(mkcolRequest).execute().use { response -> if (!response.isSuccessful) { - val errorBody = response.body?.string() - Logs.e("WebDAV backup - MKCOL error: $errorBody") + Logs.e("WebDAV backup - MKCOL failed: ${response.code} ${response.message}") throw Exception("Failed to create directory (${response.code}): ${response.message}") } } @@ -296,9 +299,8 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { client.newCall(putRequest).execute().use { response -> if (!response.isSuccessful) { - val errorBody = response.body?.string() - Logs.e("WebDAV backup - PUT error: $errorBody") - throw Exception("Upload failed (${response.code}): ${response.message}\n$errorBody") + Logs.e("WebDAV backup - PUT failed: ${response.code} ${response.message}") + throw Exception("Upload failed (${response.code}): ${response.message}") } Logs.d("WebDAV backup - Upload successful") } @@ -351,7 +353,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { } }.build() - Logs.d("WebDAV restore - Directory URL: $dirUrl") + Logs.d("WebDAV restore - Directory URL: ${redactedWebDavUrlForLog(dirUrl)}") // first list the directory contents to find the latest backup file val propfindRequest = Request.Builder() @@ -369,14 +371,14 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { // get the latest backup file name val latestBackup = client.newCall(propfindRequest).execute().use { response -> - if (!response.isSuccessful && response.code != 207) { - val errorBody = response.body?.string() - Logs.e("WebDAV restore - PROPFIND error: $errorBody") - throw Exception("Failed to list directory: ${response.message}") + if (response.code != 207) { + Logs.e("WebDAV restore - PROPFIND failed: ${response.code} ${response.message}") + throw Exception("Failed to list directory (${response.code}): ${response.message}") } - val responseBody = response.body?.string() ?: throw Exception("Empty response") - Logs.d("WebDAV restore - Directory listing: $responseBody") + val responseBody = response.body?.byteStream() + ?.use { it.readBytesBounded().toString(Charsets.UTF_8) } + ?: throw Exception("Empty response") val patterns = listOf( """[^<]*?nekobox_backup_[^<]*?\d{8}_\d{6}\.(json|zip)""".toRegex(), @@ -390,7 +392,6 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { val matches = pattern.findAll(responseBody) matches.forEach { match -> val href = match.value - Logs.d("WebDAV restore - Found backup file with pattern ${pattern.pattern}: $href") val fileName = """nekobox_backup_[^<]*?\d{8}_\d{6}\.(json|zip)""".toRegex() .find(href)?.value if (fileName != null) { @@ -400,7 +401,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { if (backupFiles.isNotEmpty()) break } - Logs.d("WebDAV restore - Found ${backupFiles.size} backup files: ${backupFiles.joinToString()}") + Logs.d("WebDAV restore - Found ${backupFiles.size} candidate backup files") backupFiles.maxByOrNull { fileName -> """(\d{8}_\d{6})""".toRegex().find(fileName)?.value ?: "" @@ -411,7 +412,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { val fileUrl = dirUrl.newBuilder() .addPathSegment(latestBackup) .build() - Logs.d("WebDAV restore - File URL: $fileUrl") + Logs.d("WebDAV restore - File URL: ${redactedWebDavUrlForLog(fileUrl)}") val getRequest = Request.Builder() .url(fileUrl) @@ -427,8 +428,7 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { val content = client.newCall(getRequest).execute().use { response -> if (!response.isSuccessful) { - val errorBody = response.body?.string() - Logs.e("WebDAV restore - GET error: $errorBody") + Logs.e("WebDAV restore - GET failed: ${response.code} ${response.message}") throw Exception("Download failed (${response.code}): ${response.message}") } response.body?.byteStream()?.use { it.readBytesBounded() } @@ -545,45 +545,16 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { DataStore.configurationStore.awaitWrites() } val out = JSONObject().apply { - put("version", 1) + put("version", BackupFormatV2.VERSION) if (profile) { - put( - "profiles", - JSONArray().apply { - SagerDatabase.proxyDao.getAll().forEach { - put(it.toBase64Str()) - } - }, - ) - - put( - "groups", - JSONArray().apply { - SagerDatabase.groupDao.allGroups().forEach { - put(it.toBase64Str()) - } - }, - ) + put("profiles", BackupFormatV2.encodeProfiles(SagerDatabase.proxyDao.getAll())) + put("groups", BackupFormatV2.encodeGroups(SagerDatabase.groupDao.allGroups())) } if (rule) { - put( - "rules", - JSONArray().apply { - SagerDatabase.rulesDao.allRules().forEach { - put(it.toBase64Str()) - } - }, - ) + put("rules", BackupFormatV2.encodeRules(SagerDatabase.rulesDao.allRules())) } if (setting) { - put( - "settings", - JSONArray().apply { - PublicDatabase.kvPairDao.all().forEach { - put(it.toBase64Str()) - } - }, - ) + put("settings", BackupFormatV2.encodeSettings(PublicDatabase.kvPairDao.all())) } } @@ -716,24 +687,41 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { // // profiles and groups are one logical section (a config backup always contains both); // they are decoded and committed together so neither is left without the other. + val version = content.optInt("version", 1) + if (version != 1 && version != BackupFormatV2.VERSION) { + error("Unsupported backup version: $version") + } + val importConfigs = profile && content.has("profiles") val profiles = if (importConfigs) { - decodeArray(content.getJSONArray("profiles")) { ProxyEntity.CREATOR.createFromParcel(it) } + when (version) { + BackupFormatV2.VERSION -> BackupFormatV2.decodeProfiles(content.getJSONArray("profiles")) + else -> decodeArray(content.getJSONArray("profiles")) { ProxyEntity.CREATOR.createFromParcel(it) } + } } else { null } val groups = if (importConfigs) { - decodeArray(content.getJSONArray("groups")) { ProxyGroup.CREATOR.createFromParcel(it) } + when (version) { + BackupFormatV2.VERSION -> BackupFormatV2.decodeGroups(content.getJSONArray("groups")) + else -> decodeArray(content.getJSONArray("groups")) { ProxyGroup.CREATOR.createFromParcel(it) } + } } else { null } val rules = if (rule && content.has("rules")) { - decodeArray(content.getJSONArray("rules")) { ParcelizeBridge.createRule(it) } + when (version) { + BackupFormatV2.VERSION -> BackupFormatV2.decodeRules(content.getJSONArray("rules")) + else -> decodeArray(content.getJSONArray("rules")) { ParcelizeBridge.createRule(it) } + } } else { null } val settings = if (setting && content.has("settings")) { - decodeArray(content.getJSONArray("settings")) { KeyValuePair.CREATOR.createFromParcel(it) } + when (version) { + BackupFormatV2.VERSION -> BackupFormatV2.decodeSettings(content.getJSONArray("settings")) + else -> decodeArray(content.getJSONArray("settings")) { KeyValuePair.CREATOR.createFromParcel(it) } + } } else { null } @@ -773,9 +761,9 @@ class BackupFragment : NamedFragment(R.layout.layout_backup) { * throws (aborting the whole import) before the caller commits any reset+insert. * * NOTE: Parcel.unmarshall on imported bytes is a known deserialization hazard and an - * unstable persistence format (see Plan 014). This retains backward compatibility with - * existing backups; the validate-then-commit ordering above removes the partial-wipe - * risk. Migrating the encoding off Parcel is tracked as follow-up work. + * unstable persistence format. This legacy path is retained only for version 1 backups; + * new exports use the explicit Parcel-free version 2 schema. The validate-then-commit + * ordering above removes the partial-wipe risk for legacy imports. */ private fun decodeArray(array: JSONArray, create: (Parcel) -> T): List { val out = ArrayList(array.length()) diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/ConfigurationFragment.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/ConfigurationFragment.kt index 762bc0e83..9ac56d1ae 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/ui/ConfigurationFragment.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/ConfigurationFragment.kt @@ -1681,10 +1681,9 @@ class ConfigurationFragment @JvmOverloads constructor( if (::undoManager.isInitialized) { undoManager.flush() } + val oldProfile = configurationList[profile.id] configurationList[profile.id] = profile notifyItemChanged(index) - // - val oldProfile = configurationList[profile.id] if (noTraffic && oldProfile != null) { runOnDefaultDispatcher { onUpdated( @@ -1879,8 +1878,11 @@ class ConfigurationFragment @JvmOverloads constructor( if (update) { ProfileManager.postUpdate(lastSelected) if (DataStore.serviceState.canStop && reloadAccess.tryLock()) { - SagerNet.reloadService(proxyEntity.id) - reloadAccess.unlock() + try { + SagerNet.reloadService(proxyEntity.id) + } finally { + reloadAccess.unlock() + } } } else if (SagerNet.isTv) { if (DataStore.serviceState.started) { diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/WebDAVSecurity.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/WebDAVSecurity.kt index 5b3379d65..1eb6fe6c2 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/ui/WebDAVSecurity.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/WebDAVSecurity.kt @@ -13,6 +13,13 @@ import okhttp3.HttpUrl.Companion.toHttpUrlOrNull * would transmit both the credentials and the secret-bearing backup in cleartext, * so only TLS (`https://`) endpoints are accepted. */ +internal fun redactedWebDavUrlForLog(url: HttpUrl): String { + val defaultPort = HttpUrl.defaultPort(url.scheme) + val host = if (url.host.contains(':')) "[${url.host}]" else url.host + val port = if (url.port != defaultPort) ":${url.port}" else "" + return "${url.scheme}://$host$port/" +} + object WebDAVSecurity { /** diff --git a/app/src/main/java/io/nekohasekai/sagernet/ui/profile/ProfileSettingsActivity.kt b/app/src/main/java/io/nekohasekai/sagernet/ui/profile/ProfileSettingsActivity.kt index b52d51a71..13649009a 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/ui/profile/ProfileSettingsActivity.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/ui/profile/ProfileSettingsActivity.kt @@ -34,6 +34,7 @@ import io.nekohasekai.sagernet.* import io.nekohasekai.sagernet.database.DataStore import io.nekohasekai.sagernet.database.GroupManager import io.nekohasekai.sagernet.database.ProfileManager +import io.nekohasekai.sagernet.database.ProxyGroup import io.nekohasekai.sagernet.database.SagerDatabase import io.nekohasekai.sagernet.database.preference.OnPreferenceDataStoreChangeListener import io.nekohasekai.sagernet.databinding.LayoutGroupItemBinding @@ -90,6 +91,8 @@ abstract class ProfileSettingsActivity( val proxyEntity by lazy { SagerDatabase.proxyDao.getById(DataStore.editingId) } protected var isSubscription by Delegates.notNull() + private var canMoveToOtherBasicGroup = false + private var moveTargetGroups: List = emptyList() override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) @@ -118,19 +121,27 @@ abstract class ProfileSettingsActivity( runOnDefaultDispatcher { if (editingId == 0L) { DataStore.editingGroup = DataStore.selectedGroupForImport() + canMoveToOtherBasicGroup = false + moveTargetGroups = emptyList() createEntity().applyDefaultValues().init() } else { - if (proxyEntity == null) { + val entity = proxyEntity + if (entity == null) { onMainDispatcher { finish() } return@runOnDefaultDispatcher } - DataStore.editingGroup = proxyEntity!!.groupId - (proxyEntity!!.requireBean() as T).init() + DataStore.editingGroup = entity.groupId + val groups = SagerDatabase.groupDao.allGroups() + moveTargetGroups = groups.filter { it.type == GroupType.BASIC && it.id != entity.groupId } + canMoveToOtherBasicGroup = groups.firstOrNull { it.id == entity.groupId }?.type == + GroupType.BASIC && moveTargetGroups.isNotEmpty() + (entity.requireBean() as T).init() } onMainDispatcher { + invalidateOptionsMenu() supportFragmentManager.beginTransaction() .replace(R.id.settings, MyPreferenceFragmentCompat()) .commit() @@ -162,13 +173,7 @@ abstract class ProfileSettingsActivity( override fun onCreateOptionsMenu(menu: Menu): Boolean { menuInflater.inflate(R.menu.profile_config_menu, menu) menu.findItem(R.id.action_move)?.apply { - if (DataStore.editingId != 0L && // not new profile - SagerDatabase.groupDao.getById(DataStore.editingGroup)?.type == GroupType.BASIC && // not in subscription group - SagerDatabase.groupDao.allGroups() - .filter { it.type == GroupType.BASIC }.size > 1 // have other basic group - ) { - isVisible = true - } + isVisible = DataStore.editingId != 0L && canMoveToOtherBasicGroup } menu.findItem(R.id.action_create_shortcut)?.apply { if (Build.VERSION.SDK_INT >= 26 && DataStore.editingId != 0L) { @@ -342,13 +347,13 @@ abstract class ProfileSettingsActivity( R.id.action_move -> { val activity = requireActivity() as ProfileSettingsActivity<*> - val view = LinearLayout(context).apply { - val ent = activity.proxyEntity!! - orientation = LinearLayout.VERTICAL + val ent = activity.proxyEntity!! + val groups = activity.moveTargetGroups + if (groups.isNotEmpty()) { + val view = LinearLayout(context).apply { + orientation = LinearLayout.VERTICAL - SagerDatabase.groupDao.allGroups() - .filter { it.type == GroupType.BASIC && it.id != ent.groupId } - .forEach { group -> + groups.forEach { group -> LayoutGroupItemBinding.inflate(layoutInflater, this, true).apply { edit.isVisible = false options.isVisible = false @@ -370,11 +375,12 @@ abstract class ProfileSettingsActivity( } } } + } + val scrollView = ScrollView(context).apply { + addView(view) + } + MaterialAlertDialogBuilder(activity).setView(scrollView).show() } - val scrollView = ScrollView(context).apply { - addView(view) - } - MaterialAlertDialogBuilder(activity).setView(scrollView).show() true } diff --git a/app/src/main/java/io/nekohasekai/sagernet/widget/GroupPreference.kt b/app/src/main/java/io/nekohasekai/sagernet/widget/GroupPreference.kt index 25a662f9d..5ef3cff09 100644 --- a/app/src/main/java/io/nekohasekai/sagernet/widget/GroupPreference.kt +++ b/app/src/main/java/io/nekohasekai/sagernet/widget/GroupPreference.kt @@ -4,6 +4,8 @@ import android.content.Context import android.util.AttributeSet import io.nekohasekai.sagernet.R import io.nekohasekai.sagernet.database.SagerDatabase +import io.nekohasekai.sagernet.ktx.runOnDefaultDispatcher +import io.nekohasekai.sagernet.ktx.runOnMainDispatcher import moe.matsuri.nb4a.ui.SimpleMenuPreference class GroupPreference @@ -13,17 +15,30 @@ class GroupPreference defStyle: Int = R.attr.dropdownPreferenceStyle, ) : SimpleMenuPreference(context, attrs, defStyle, 0) { + private val groupNames = mutableMapOf() + init { - val groups = SagerDatabase.groupDao.allGroups() + val wasEnabled = isEnabled + entries = emptyArray() + entryValues = emptyArray() + isEnabled = false - entries = groups.map { it.displayName() }.toTypedArray() - entryValues = groups.map { "${it.id}" }.toTypedArray() + runOnDefaultDispatcher { + val groups = SagerDatabase.groupDao.allGroups() + runOnMainDispatcher { + groupNames.clear() + groupNames.putAll(groups.associate { it.id to it.displayName() }) + entries = groups.map { it.displayName() }.toTypedArray() + entryValues = groups.map { "${it.id}" }.toTypedArray() + isEnabled = wasEnabled + notifyChanged() + } + } } override fun getSummary(): CharSequence? { if (!value.isNullOrBlank() && value != "0") { - return SagerDatabase.groupDao.getById(value.toLong())?.displayName() - ?: super.getSummary() + return groupNames[value.toLongOrNull()] ?: super.getSummary() } return super.getSummary() } diff --git a/app/src/test/java/io/nekohasekai/sagernet/bg/SubscriptionUpdaterScheduleTest.kt b/app/src/test/java/io/nekohasekai/sagernet/bg/SubscriptionUpdaterScheduleTest.kt new file mode 100644 index 000000000..0905dae96 --- /dev/null +++ b/app/src/test/java/io/nekohasekai/sagernet/bg/SubscriptionUpdaterScheduleTest.kt @@ -0,0 +1,65 @@ +package io.nekohasekai.sagernet.bg + +import org.junit.Assert.assertEquals +import org.junit.Test + +class SubscriptionUpdaterScheduleTest { + + @Test + fun overdueSubscription_schedulesImmediately() { + val schedule = computeSubscriptionWorkSchedule( + listOf(SubscriptionScheduleInput(lastUpdated = 1_000, autoUpdateDelay = 15)), + nowSeconds = 1_900, + )!! + + assertEquals(15L, schedule.intervalMinutes) + assertEquals(0L, schedule.initialDelaySeconds) + } + + @Test + fun nearFutureSubscription_usesSecondsUntilDue() { + val schedule = computeSubscriptionWorkSchedule( + listOf(SubscriptionScheduleInput(lastUpdated = 1_000, autoUpdateDelay = 15)), + nowSeconds = 1_870, + )!! + + assertEquals(15L, schedule.intervalMinutes) + assertEquals(30L, schedule.initialDelaySeconds) + } + + @Test + fun farFutureSubscription_preservesSecondsUntilDue() { + val schedule = computeSubscriptionWorkSchedule( + listOf(SubscriptionScheduleInput(lastUpdated = 1_000, autoUpdateDelay = 60)), + nowSeconds = 1_100, + )!! + + assertEquals(60L, schedule.intervalMinutes) + assertEquals(3_500L, schedule.initialDelaySeconds) + } + + @Test + fun delayBelowWorkManagerMinimum_isCoercedForIntervalAndDueTime() { + val schedule = computeSubscriptionWorkSchedule( + listOf(SubscriptionScheduleInput(lastUpdated = 1_000, autoUpdateDelay = 5)), + nowSeconds = 1_870, + )!! + + assertEquals(15L, schedule.intervalMinutes) + assertEquals(30L, schedule.initialDelaySeconds) + } + + @Test + fun multipleSubscriptions_useSoonestDueSubscription() { + val schedule = computeSubscriptionWorkSchedule( + listOf( + SubscriptionScheduleInput(lastUpdated = 1_000, autoUpdateDelay = 60), + SubscriptionScheduleInput(lastUpdated = 2_000, autoUpdateDelay = 15), + ), + nowSeconds = 2_870, + )!! + + assertEquals(15L, schedule.intervalMinutes) + assertEquals(30L, schedule.initialDelaySeconds) + } +} diff --git a/app/src/test/java/io/nekohasekai/sagernet/ui/BackupFileNameTest.kt b/app/src/test/java/io/nekohasekai/sagernet/ui/BackupFileNameTest.kt new file mode 100644 index 000000000..955422b7b --- /dev/null +++ b/app/src/test/java/io/nekohasekai/sagernet/ui/BackupFileNameTest.kt @@ -0,0 +1,27 @@ +package io.nekohasekai.sagernet.ui + +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Test +import java.util.Date + +class BackupFileNameTest { + + @Test + fun fixedDate_hasStableAsciiTimestampShape() { + val name = backupFileName(Date(0)) + + assertTrue(Regex("nekobox_backup_\\d{8}_\\d{6}\\.json").matches(name)) + } + + @Test + fun generatedName_hasNoUnsafeFilenameCharacters() { + val name = backupFileName(Date(0)) + + assertFalse(name.contains('/')) + assertFalse(name.contains('\\')) + assertFalse(name.contains(':')) + assertFalse(name.contains('\n')) + assertFalse(name.contains(' ')) + } +} diff --git a/app/src/test/java/io/nekohasekai/sagernet/ui/BackupFormatV2Test.kt b/app/src/test/java/io/nekohasekai/sagernet/ui/BackupFormatV2Test.kt new file mode 100644 index 000000000..d7c1c24e1 --- /dev/null +++ b/app/src/test/java/io/nekohasekai/sagernet/ui/BackupFormatV2Test.kt @@ -0,0 +1,194 @@ +package io.nekohasekai.sagernet.ui + +import io.nekohasekai.sagernet.GroupType +import io.nekohasekai.sagernet.database.ProxyEntity +import io.nekohasekai.sagernet.database.ProxyGroup +import io.nekohasekai.sagernet.database.RuleEntity +import io.nekohasekai.sagernet.database.SubscriptionBean +import io.nekohasekai.sagernet.database.preference.KeyValuePair +import io.nekohasekai.sagernet.fmt.socks.SOCKSBean +import io.nekohasekai.sagernet.fmt.trojan.TrojanBean +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import org.robolectric.annotation.Config + +@RunWith(RobolectricTestRunner::class) +@Config(sdk = [33], application = android.app.Application::class) +class BackupFormatV2Test { + + @Test + fun profileRoundTrip_preservesTypeAndBean() { + val profile = ProxyEntity( + id = 10L, + groupId = 20L, + userOrder = 30L, + tx = 40L, + rx = 50L, + status = 1, + ping = 123, + uuid = "profile-uuid", + error = "last error", + ).apply { + putBean( + SOCKSBean().apply { + serverAddress = "192.0.2.10" + serverPort = 1080 + username = "user" + password = "pass" + protocol = 2 + name = "socks node" + initializeDefaultValues() + }, + ) + } + + val decoded = BackupFormatV2.decodeProfile(BackupFormatV2.encodeProfile(profile)) + val bean = decoded.requireBean() as SOCKSBean + + assertEquals(profile.id, decoded.id) + assertEquals(profile.groupId, decoded.groupId) + assertEquals(profile.type, decoded.type) + assertEquals(profile.userOrder, decoded.userOrder) + assertEquals(profile.tx, decoded.tx) + assertEquals(profile.rx, decoded.rx) + assertEquals(profile.status, decoded.status) + assertEquals(profile.ping, decoded.ping) + assertEquals(profile.uuid, decoded.uuid) + assertEquals(profile.error, decoded.error) + assertEquals("192.0.2.10", bean.serverAddress) + assertEquals(1080, bean.serverPort) + assertEquals("user", bean.username) + assertEquals("pass", bean.password) + assertEquals(2, bean.protocol) + assertEquals("socks node", bean.name) + } + + @Test + fun profileRoundTrip_trojanBeanPreservesFields() { + val profile = ProxyEntity(id = 11L, groupId = 22L).apply { + putBean( + TrojanBean().apply { + serverAddress = "example.com" + serverPort = 443 + password = "secret" + name = "trojan node" + initializeDefaultValues() + }, + ) + } + + val decoded = BackupFormatV2.decodeProfile(BackupFormatV2.encodeProfile(profile)) + val bean = decoded.requireBean() as TrojanBean + + assertEquals(ProxyEntity.TYPE_TROJAN, decoded.type) + assertEquals("example.com", bean.serverAddress) + assertEquals(443, bean.serverPort) + assertEquals("secret", bean.password) + assertEquals("trojan node", bean.name) + } + + @Test + fun groupsRoundTrip_preservesBasicAndSubscriptionGroups() { + val subscription = SubscriptionBean().apply { + initializeDefaultValues() + link = "https://example.com/sub" + autoUpdate = true + autoUpdateDelay = 30 + lastUpdated = 1234 + } + val groups = listOf( + ProxyGroup( + id = 1L, + userOrder = 2L, + ungrouped = true, + name = null, + type = GroupType.BASIC, + isSelector = true, + frontProxy = 10L, + landingProxy = 11L, + ), + ProxyGroup( + id = 3L, + userOrder = 4L, + name = "subscription group", + type = GroupType.SUBSCRIPTION, + subscription = subscription, + ), + ) + + val decoded = BackupFormatV2.decodeGroups(BackupFormatV2.encodeGroups(groups)) + + assertEquals(2, decoded.size) + assertEquals(1L, decoded[0].id) + assertEquals(true, decoded[0].ungrouped) + assertNull(decoded[0].name) + assertEquals(true, decoded[0].isSelector) + assertEquals(10L, decoded[0].frontProxy) + assertEquals(11L, decoded[0].landingProxy) + assertEquals(GroupType.SUBSCRIPTION, decoded[1].type) + assertEquals("subscription group", decoded[1].name) + assertEquals("https://example.com/sub", decoded[1].subscription!!.link) + assertEquals(true, decoded[1].subscription!!.autoUpdate) + assertEquals(30, decoded[1].subscription!!.autoUpdateDelay) + assertEquals(1234, decoded[1].subscription!!.lastUpdated) + } + + @Test + fun ruleRoundTrip_preservesPackages() { + val rule = RuleEntity( + id = 5L, + name = "rule", + config = "config", + userOrder = 6L, + enabled = true, + domains = "example.com", + ip = "192.0.2.0/24", + port = "443", + sourcePort = "1000:2000", + network = "tcp", + source = "10.0.0.1", + protocol = "tls", + ruleset = "geoip-cn", + outbound = -1L, + packages = setOf("com.example.one", "com.example.two"), + ) + + val decoded = BackupFormatV2.decodeRule(BackupFormatV2.encodeRule(rule)) + + assertEquals(rule.id, decoded.id) + assertEquals(rule.name, decoded.name) + assertEquals(rule.config, decoded.config) + assertEquals(rule.userOrder, decoded.userOrder) + assertEquals(rule.enabled, decoded.enabled) + assertEquals(rule.domains, decoded.domains) + assertEquals(rule.ip, decoded.ip) + assertEquals(rule.port, decoded.port) + assertEquals(rule.sourcePort, decoded.sourcePort) + assertEquals(rule.network, decoded.network) + assertEquals(rule.source, decoded.source) + assertEquals(rule.protocol, decoded.protocol) + assertEquals(rule.ruleset, decoded.ruleset) + assertEquals(rule.outbound, decoded.outbound) + assertEquals(rule.packages, decoded.packages) + } + + @Test + fun settingsRoundTrip_preservesStringAndStringSetValues() { + val settings = listOf( + KeyValuePair("string-key").put("value"), + KeyValuePair("set-key").put(setOf("one", "two")), + ) + + val decoded = BackupFormatV2.decodeSettings(BackupFormatV2.encodeSettings(settings)) + + assertEquals("string-key", decoded[0].key) + assertEquals(KeyValuePair.TYPE_STRING, decoded[0].valueType) + assertEquals("value", decoded[0].string) + assertEquals("set-key", decoded[1].key) + assertEquals(KeyValuePair.TYPE_STRING_SET, decoded[1].valueType) + assertEquals(setOf("one", "two"), decoded[1].stringSet) + } +} diff --git a/app/src/test/java/io/nekohasekai/sagernet/ui/WebDAVSecurityTest.kt b/app/src/test/java/io/nekohasekai/sagernet/ui/WebDAVSecurityTest.kt new file mode 100644 index 000000000..301bcfaa6 --- /dev/null +++ b/app/src/test/java/io/nekohasekai/sagernet/ui/WebDAVSecurityTest.kt @@ -0,0 +1,36 @@ +package io.nekohasekai.sagernet.ui + +import okhttp3.HttpUrl.Companion.toHttpUrl +import org.junit.Assert.assertEquals +import org.junit.Test + +class WebDAVSecurityTest { + + @Test + fun redactedUrl_removesPath() { + val redacted = redactedWebDavUrlForLog("https://example.com/secret/path/file.zip".toHttpUrl()) + + assertEquals("https://example.com/", redacted) + } + + @Test + fun redactedUrl_removesCredentialsAndQuery() { + val redacted = redactedWebDavUrlForLog("https://user:pass@example.com/private?token=x".toHttpUrl()) + + assertEquals("https://example.com/", redacted) + } + + @Test + fun redactedUrl_preservesNonDefaultPort() { + val redacted = redactedWebDavUrlForLog("https://example.com:8443/a/b".toHttpUrl()) + + assertEquals("https://example.com:8443/", redacted) + } + + @Test + fun redactedUrl_bracketsIpv6Host() { + val redacted = redactedWebDavUrlForLog("https://[2001:db8::1]:8443/a/b".toHttpUrl()) + + assertEquals("https://[2001:db8::1]:8443/", redacted) + } +} diff --git a/buildScript/lib/assets.sh b/buildScript/lib/assets.sh index 7c07c55d8..1918d60c3 100755 --- a/buildScript/lib/assets.sh +++ b/buildScript/lib/assets.sh @@ -1,28 +1,51 @@ #!/bin/bash set -e +set -o pipefail -DIR=app/src/main/assets/sing-box -rm -rf $DIR -mkdir -p $DIR -cd $DIR +GEOIP_VERSION="${GEOIP_VERSION:-20260612}" +GEOIP_SHA256="${GEOIP_SHA256:-71484cf35bb48453e26bcc3373a0988a2536588f8e3ca96cda59ff742af6c392}" +GEOSITE_VERSION="${GEOSITE_VERSION:-20260625041655}" +GEOSITE_SHA256="${GEOSITE_SHA256:-7e4220f1700bcb63204b11c9a5a07d1c315d1262c3e0049f23d548b0b7b0343a}" + +sha256_tool() { + if command -v sha256sum >/dev/null 2>&1; then sha256sum "$1" | awk '{print $1}' + else shasum -a 256 "$1" | awk '{print $1}'; fi +} -get_latest_release() { - curl --silent "https://api.github.com/repos/$1/releases/latest" | # Get latest release from GitHub api - grep '"tag_name":' | # Get tag line - sed -E 's/.*"([^"]+)".*/\1/' # Pluck JSON value +download_verified() { + local url="$1" output="$2" expected="$3" tmp actual + tmp="${output}.download" + rm -f "$tmp" + curl -fL --retry 3 --retry-delay 2 --max-time 300 "$url" -o "$tmp" + actual="$(sha256_tool "$tmp")" + if [ "$expected" != "$actual" ]; then + rm -f "$tmp" + echo "Error: checksum mismatch for $output (expected $expected, got $actual)" >&2 + exit 1 + fi + mv "$tmp" "$output" } +DIR=app/src/main/assets/sing-box +rm -rf "$DIR" +mkdir -p "$DIR" +cd "$DIR" + #### -VERSION_GEOIP=`get_latest_release "SagerNet/sing-geoip"` -echo VERSION_GEOIP=$VERSION_GEOIP -echo -n $VERSION_GEOIP > geoip.version.txt -curl -fLSsO https://github.com/SagerNet/sing-geoip/releases/download/$VERSION_GEOIP/geoip.db +echo VERSION_GEOIP=$GEOIP_VERSION +echo -n "$GEOIP_VERSION" > geoip.version.txt +download_verified \ + "https://github.com/SagerNet/sing-geoip/releases/download/$GEOIP_VERSION/geoip.db" \ + geoip.db \ + "$GEOIP_SHA256" xz -9 geoip.db #### -VERSION_GEOSITE=`get_latest_release "SagerNet/sing-geosite"` -echo VERSION_GEOSITE=$VERSION_GEOSITE -echo -n $VERSION_GEOSITE > geosite.version.txt -curl -fLSsO https://github.com/SagerNet/sing-geosite/releases/download/$VERSION_GEOSITE/geosite.db +echo VERSION_GEOSITE=$GEOSITE_VERSION +echo -n "$GEOSITE_VERSION" > geosite.version.txt +download_verified \ + "https://github.com/SagerNet/sing-geosite/releases/download/$GEOSITE_VERSION/geosite.db" \ + geosite.db \ + "$GEOSITE_SHA256" xz -9 geosite.db diff --git a/buildScript/lib/naive.sh b/buildScript/lib/naive.sh index fb5100633..c61ddfc07 100755 --- a/buildScript/lib/naive.sh +++ b/buildScript/lib/naive.sh @@ -13,23 +13,57 @@ set -e set -o pipefail -# Pinned naiveproxy release. +# Pinned naiveproxy release. The SHA256 values below are for the downloaded +# plugin APKs for this exact version; update them in the same change as any +# NAIVE_VERSION bump. NAIVE_VERSION="${NAIVE_VERSION:-v149.0.7827.114-1}" +NAIVE_SHA256_ARM64_V8A="${NAIVE_SHA256_ARM64_V8A:-07f58c14849f3fb047d342fdc8e34d65a745a133f436469673f29624bba87f6a}" +NAIVE_SHA256_ARMEABI_V7A="${NAIVE_SHA256_ARMEABI_V7A:-be0e126d2631a0a4c8f9140595243f51a8c676c0756deb67144677ebfe7d7202}" +NAIVE_SHA256_X86="${NAIVE_SHA256_X86:-82a3b8ef29876ccaa6f7df4dc3dabfaa92eb954a7de8a3e0ff93f92afc17e9ca}" +NAIVE_SHA256_X86_64="${NAIVE_SHA256_X86_64:-7957af60ac3bedaf6bd35c172297bd9e730b90ac25f6bd26fa19a4591ceec13a}" BASE="https://github.com/klzgrad/naiveproxy/releases/download/${NAIVE_VERSION}" OUT="$(pwd)/app/executableSo" WORK="$(pwd)/.naive-build" mkdir -p "$WORK" +sha256_tool() { + if command -v sha256sum >/dev/null 2>&1; then sha256sum "$1" | awk '{print $1}' + else shasum -a 256 "$1" | awk '{print $1}'; fi +} + +expected_sha256_for_abi() { + case "$1" in + arm64-v8a) echo "$NAIVE_SHA256_ARM64_V8A" ;; + armeabi-v7a) echo "$NAIVE_SHA256_ARMEABI_V7A" ;; + x86) echo "$NAIVE_SHA256_X86" ;; + x86_64) echo "$NAIVE_SHA256_X86_64" ;; + *) echo "Error: unsupported ABI $1" >&2; exit 1 ;; + esac +} + +verify_sha256() { + local file="$1" expected="$2" actual + actual="$(sha256_tool "$file")" + if [ "$expected" != "$actual" ]; then + echo "Error: checksum mismatch for $(basename "$file") (expected $expected, got $actual)" >&2 + exit 1 + fi +} + # Map Android ABI -> naiveproxy plugin APK ABI tag (identical here). extract_abi() { local abi="$1" local apk="naiveproxy-plugin-${NAIVE_VERSION}-${abi}.apk" + local apk_path="$WORK/$apk" echo ">> fetching libnaive.so for $abi ($apk)" - curl -fL --retry 3 --retry-delay 2 --max-time 300 "$BASE/$apk" -o "$WORK/$apk" + curl -fL --retry 3 --retry-delay 2 --max-time 300 "$BASE/$apk" -o "$apk_path" + verify_sha256 "$apk_path" "$(expected_sha256_for_abi "$abi")" + mkdir -p "$OUT/$abi" - # The plugin APK ships the client as lib//libnaive.so. - unzip -o -j "$WORK/$apk" "lib/$abi/libnaive.so" -d "$OUT/$abi" >/dev/null + # The plugin APK ships the client as lib//libnaive.so. Extract only after + # the downloaded APK matches the pinned SHA256. + unzip -o -j "$apk_path" "lib/$abi/libnaive.so" -d "$OUT/$abi" >/dev/null if [ ! -f "$OUT/$abi/libnaive.so" ]; then echo "Error: libnaive.so not found in $apk" >&2 exit 1 diff --git a/libcore/http.go b/libcore/http.go index c737e2701..2e4a3bf6f 100644 --- a/libcore/http.go +++ b/libcore/http.go @@ -17,6 +17,7 @@ import ( "net/http" "net/url" "os" + "path/filepath" "strconv" "sync" "sync/atomic" @@ -31,6 +32,17 @@ import ( var errFailConnectSocks5 = errors.New("fail connect socks5") +const ( + defaultHTTPTimeout = 120 * time.Second + defaultHTTPDialTimeout = 30 * time.Second + // defaultHTTPRequestTimeout bounds the whole request including body reads, + // so GetContentLimited/WriteToLimited cannot block forever on a stalled body. + // It is generous enough for large rule-asset downloads over slow links. + defaultHTTPRequestTimeout = 10 * time.Minute + defaultHTTPStringLimit = 10 * 1024 * 1024 + defaultHTTPFileLimit = 256 * 1024 * 1024 +) + type HTTPClient interface { RestrictedTLS() ModernTLS() @@ -57,8 +69,11 @@ type HTTPRequest interface { type HTTPResponse interface { GetHeader(string) *StringBox GetContent() ([]byte, error) + GetContentLimited(limit int64) ([]byte, error) GetContentString() (*StringBox, error) + GetContentStringLimited(limit int64) (*StringBox, error) WriteTo(path string) error + WriteToLimited(path string, limit int64) error } var ( @@ -77,9 +92,15 @@ type httpClient struct { func NewHttpClient() HTTPClient { client := new(httpClient) + dialer := &net.Dialer{Timeout: defaultHTTPDialTimeout} client.h1h2Client.Transport = &client.h1h2Transport + client.h1h2Transport.DialContext = dialer.DialContext client.h1h2Transport.TLSClientConfig = &client.tls + client.h1h2Transport.TLSHandshakeTimeout = defaultHTTPTimeout + client.h1h2Transport.ResponseHeaderTimeout = defaultHTTPTimeout client.h1h2Transport.DisableKeepAlives = true + // Bound the full request (including body read) so callers cannot hang forever. + client.h1h2Client.Timeout = defaultHTTPRequestTimeout return client } @@ -115,7 +136,7 @@ func (c *httpClient) PinnedSHA256(sumHex string) { } func (c *httpClient) TrySocks5(port int32, username string, password string) { - dialer := new(net.Dialer) + dialer := &net.Dialer{Timeout: defaultHTTPDialTimeout} c.h1h2Transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { for { socksConn, err := dialer.DialContext(ctx, "tcp", "127.0.0.1:"+strconv.Itoa(int(port))) @@ -244,6 +265,7 @@ func (r *httpRequest) doH3Direct() (HTTPResponse, error) { func() (response *http.Response, err error) { request := r.request.Clone(context.Background()) echClient := &http.Client{ + Timeout: defaultHTTPRequestTimeout, Transport: &http.Transport{ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer @@ -267,6 +289,7 @@ func (r *httpRequest) doH3Direct() (HTTPResponse, error) { func() (response *http.Response, err error) { request := r.request.Clone(context.Background()) h3Client := &http.Client{ + Timeout: defaultHTTPRequestTimeout, Transport: &http3.Transport{ TLSClientConfig: r.tls.Clone(), QUICConfig: &quic.Config{ @@ -345,9 +368,10 @@ func (r *httpRequest) doH3Direct() (HTTPResponse, error) { type httpResponse struct { *http.Response - getContentOnce sync.Once - content []byte - contentError error + contentMu sync.Mutex + contentRead bool + content []byte + contentError error } func (h *httpResponse) errorString() string { @@ -366,15 +390,33 @@ func (h *httpResponse) GetHeader(key string) *StringBox { } func (h *httpResponse) GetContent() ([]byte, error) { - h.getContentOnce.Do(func() { - defer h.Body.Close() - h.content, h.contentError = io.ReadAll(h.Body) - }) + return h.GetContentLimited(defaultHTTPStringLimit) +} + +func (h *httpResponse) GetContentLimited(limit int64) ([]byte, error) { + h.contentMu.Lock() + defer h.contentMu.Unlock() + if h.contentRead { + if h.contentError != nil { + return nil, h.contentError + } + if int64(len(h.content)) > limit { + return nil, fmt.Errorf("HTTP response body exceeds %d bytes", limit) + } + return h.content, nil + } + defer h.Body.Close() + h.contentRead = true + h.content, h.contentError = readAllLimited(h.Body, limit) return h.content, h.contentError } func (h *httpResponse) GetContentString() (*StringBox, error) { - content, err := h.getContentString() + return h.GetContentStringLimited(defaultHTTPStringLimit) +} + +func (h *httpResponse) GetContentStringLimited(limit int64) (*StringBox, error) { + content, err := h.getContentStringLimited(limit) if err != nil { return nil, err } @@ -382,7 +424,11 @@ func (h *httpResponse) GetContentString() (*StringBox, error) { } func (h *httpResponse) getContentString() (string, error) { - content, err := h.GetContent() + return h.getContentStringLimited(defaultHTTPStringLimit) +} + +func (h *httpResponse) getContentStringLimited(limit int64) (string, error) { + content, err := h.GetContentLimited(limit) if err != nil { return "", err } @@ -390,12 +436,61 @@ func (h *httpResponse) getContentString() (string, error) { } func (h *httpResponse) WriteTo(path string) error { + return h.WriteToLimited(path, defaultHTTPFileLimit) +} + +func (h *httpResponse) WriteToLimited(path string, limit int64) (err error) { defer h.Body.Close() - file, err := os.Create(path) + dir, base := filepath.Split(path) + if dir == "" { + dir = "." + } + file, err := os.CreateTemp(dir, base+".*.tmp") if err != nil { return err } - defer file.Close() - _, err = io.Copy(file, h.Body) + tmpPath := file.Name() + defer func() { + if err != nil { + _ = os.Remove(tmpPath) + } + }() + defer func() { + if closeErr := file.Close(); err == nil { + err = closeErr + } + if err == nil { + err = os.Rename(tmpPath, path) + } + }() + _, err = copyLimited(file, h.Body, limit) return err } + +func readAllLimited(reader io.Reader, limit int64) ([]byte, error) { + if limit < 0 { + return nil, fmt.Errorf("invalid HTTP response limit %d", limit) + } + content, err := io.ReadAll(io.LimitReader(reader, limit+1)) + if err != nil { + return nil, err + } + if int64(len(content)) > limit { + return nil, fmt.Errorf("HTTP response body exceeds %d bytes", limit) + } + return content, nil +} + +func copyLimited(dst io.Writer, src io.Reader, limit int64) (int64, error) { + if limit < 0 { + return 0, fmt.Errorf("invalid HTTP response limit %d", limit) + } + written, err := io.Copy(dst, io.LimitReader(src, limit+1)) + if err != nil { + return written, err + } + if written > limit { + return written, fmt.Errorf("stream exceeds %d bytes", limit) + } + return written, nil +} diff --git a/libcore/io.go b/libcore/io.go index e4c1d3983..5c2181aa8 100644 --- a/libcore/io.go +++ b/libcore/io.go @@ -4,30 +4,51 @@ import ( "archive/zip" "io" "os" - "path/filepath" + "path/filepath" - "github.com/ulikunitz/xz" "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" + E "github.com/sagernet/sing/common/exceptions" + "github.com/ulikunitz/xz" ) -func Unxz(archive string, path string) error { +const defaultUnxzFileLimit = 256 * 1024 * 1024 + +func Unxz(archive string, path string) (err error) { i, err := os.Open(archive) if err != nil { return err } + defer i.Close() + r, err := xz.NewReader(i) if err != nil { - i.Close() return err } - o, err := os.Create(path) + + dir, base := filepath.Split(path) + if dir == "" { + dir = "." + } + o, err := os.CreateTemp(dir, base+".*.tmp") if err != nil { - i.Close() return err } - _, err = io.Copy(o, r) - i.Close() + tmpPath := o.Name() + defer func() { + if err != nil { + _ = os.Remove(tmpPath) + } + }() + defer func() { + if closeErr := o.Close(); err == nil { + err = closeErr + } + if err == nil { + err = os.Rename(tmpPath, path) + } + }() + + _, err = copyLimited(o, r, defaultUnxzFileLimit) return err }