pax_global_header00006660000000000000000000000064150561211140014506gustar00rootroot0000000000000052 comment=21f02191e81c5f5ff3e0e25ae06bcacb9c980451 golang-google-firebase-go-4.18.0/000077500000000000000000000000001505612111400165025ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/.github/000077500000000000000000000000001505612111400200425ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/.github/issue_template.md000066400000000000000000000022071505612111400234100ustar00rootroot00000000000000### [READ] Step 1: Are you in the right place? * For issues or feature requests related to __the code in this repository__ file a GitHub issue. * If this is a __feature request__ make sure the issue title starts with "FR:". * For general technical questions, post a question on [StackOverflow](http://stackoverflow.com/) with the firebase tag. * For general Firebase discussion, use the [firebase-talk](https://groups.google.com/forum/#!forum/firebase-talk) google group. * For help troubleshooting your application that does not fall under one of the above categories, reach out to the personalized [Firebase support channel](https://firebase.google.com/support/). ### [REQUIRED] Step 2: Describe your environment * Operating System version: _____ * Firebase SDK version: _____ * Library version: _____ * Firebase Product: _____ (auth, database, storage, etc) ### [REQUIRED] Step 3: Describe the problem #### Steps to reproduce: What happened? How can we make the problem occur? This could be a description, log/console output, etc. #### Relevant Code: ``` // TODO(you): code here to reproduce the problem ``` golang-google-firebase-go-4.18.0/.github/pull_request_template.md000066400000000000000000000013651505612111400250100ustar00rootroot00000000000000Hey there! So you want to contribute to a Firebase SDK? Before you file this pull request, please read these guidelines: ### Discussion * Read the contribution guidelines (CONTRIBUTING.md). 
* If this has been discussed in an issue, make sure to link to the issue here. If not, go file an issue about this **before creating a pull request** to discuss. ### Testing * Make sure all existing tests in the repository pass after your change. * If you fixed a bug or added a feature, add a new test to cover your code. ### API Changes * At this time we cannot accept changes that affect the public API. If you'd like to help us make Firebase APIs better, please propose your change in an issue so that we can discuss it together. golang-google-firebase-go-4.18.0/.github/resources/000077500000000000000000000000001505612111400220545ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/.github/resources/integ-service-account.json.gpg000066400000000000000000000033421505612111400277230ustar00rootroot00000000000000  7F_/Kdn9ӝAcf3Хf:g6fSk$#;sG8&DCͣNzGDCiamWabs)*Hg%LwrMeVid;bb?1W<5u3 qxixAM?КSȕxב=m]qSjB pbďr*؜y?fG;^l"[T6FU vqY]rWyl>>OQT?\V#YuO#lUYse׵-οU^c1-I*bmj/+J xqbbS| ýNPxЃ^j1`M ,E?A`*F#_@:]Yhwm"WRJg L wvTl;;̟}TH:;G)]0uOTBB]Yq_$ɽt^0Leћu+$L=QE<6~e){밦09 \?mDOvv̺\2ϊ6~נnK'G9Q T5\GkRCi{9wCʂ՘uun)_B42d@&;r[J#aUk;V%秮(><.z<*QLlxLF $[GqD$ I@ߨPJ0q?jLLaq|#?Κ 0+'(|wQqiyoLφ^B V+=tYY2IT~JЏ:WϨk5;!41SK V< jΝId P+l*υ@+B5Mj@y%x165iK.nn [ \[.`oyزT(?}'#R.rZC'mG`i Z7)K%RcLYhɲM* ~!˷:@AgD)^4weN>0%n7]t }{S&LH(^ã㍁ WKwBaeU%cm2u"Z?nq%]҈`AeG* P]hl E\c! 0*D^gcC\Sj07 =!IcE56+}z@_9f+Jes})9a{N2 UXlNxd*' ~Z8 Jk%0]~pj~dtjͯR5LOTW=qoV} -O~aQ)'lg1;񆟷^jD 2zcy׃%(<ՖezY2a\jHͫ1)Vɽ(OH~HЏԄA붘Y =̹敲W6t"d2g0dS^uM golang-google-firebase-go-4.18.0/.github/scripts/000077500000000000000000000000001505612111400215315ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/.github/scripts/generate_changelog.sh000077500000000000000000000043601505612111400256740ustar00rootroot00000000000000#!/bin/bash # Copyright 2020 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. set -e set -u function printChangelog() { local TITLE=$1 shift # Skip the sentinel value. local ENTRIES=("${@:2}") if [ ${#ENTRIES[@]} -ne 0 ]; then echo "### ${TITLE}" echo "" for ((i = 0; i < ${#ENTRIES[@]}; i++)) do echo "* ${ENTRIES[$i]}" done echo "" fi } if [[ -z "${GITHUB_SHA}" ]]; then GITHUB_SHA="HEAD" fi LAST_TAG=`git describe --tags $(git rev-list --tags --max-count=1) 2> /dev/null` || true if [[ -z "${LAST_TAG}" ]]; then echo "[INFO] No tags found. Including all commits up to ${GITHUB_SHA}." VERSION_RANGE="${GITHUB_SHA}" else echo "[INFO] Last release tag: ${LAST_TAG}." COMMIT_SHA=`git show-ref -s ${LAST_TAG}` echo "[INFO] Last release commit: ${COMMIT_SHA}." VERSION_RANGE="${COMMIT_SHA}..${GITHUB_SHA}" echo "[INFO] Including all commits in the range ${VERSION_RANGE}." fi echo "" # Older versions of Bash (< 4.4) treat empty arrays as unbound variables, which triggers # errors when referencing them. Therefore we initialize each of these arrays with an empty # sentinel value, and later skip them. 
CHANGES=("") FIXES=("") FEATS=("") MISC=("") while read -r line do COMMIT_MSG=`echo ${line} | cut -d ' ' -f 2-` if [[ $COMMIT_MSG =~ ^change(\(.*\))?: ]]; then CHANGES+=("$COMMIT_MSG") elif [[ $COMMIT_MSG =~ ^fix(\(.*\))?: ]]; then FIXES+=("$COMMIT_MSG") elif [[ $COMMIT_MSG =~ ^feat(\(.*\))?: ]]; then FEATS+=("$COMMIT_MSG") else MISC+=("${COMMIT_MSG}") fi done < <(git log ${VERSION_RANGE} --oneline) printChangelog "Breaking Changes" "${CHANGES[@]}" printChangelog "New Features" "${FEATS[@]}" printChangelog "Bug Fixes" "${FIXES[@]}" printChangelog "Miscellaneous" "${MISC[@]}" golang-google-firebase-go-4.18.0/.github/scripts/publish_preflight_check.sh000077500000000000000000000076431505612111400267510ustar00rootroot00000000000000#!/bin/bash # Copyright 2020 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. set -e set -u function echo_info() { local MESSAGE=$1 echo "[INFO] ${MESSAGE}" } function echo_warn() { local MESSAGE=$1 echo "[WARN] ${MESSAGE}" } function terminate() { echo "" echo_warn "--------------------------------------------" echo_warn "PREFLIGHT FAILED" echo_warn "--------------------------------------------" exit 1 } echo_info "Starting publish preflight check..." 
echo_info "Git revision : ${GITHUB_SHA}" echo_info "Git ref : ${GITHUB_REF}" echo_info "Workflow triggered by : ${GITHUB_ACTOR}" echo_info "GitHub event : ${GITHUB_EVENT_NAME}" echo_info "" echo_info "--------------------------------------------" echo_info "Extracting release version" echo_info "--------------------------------------------" echo_info "" echo_info "Loading version from: firebase.go" readonly RELEASE_VERSION=`grep "const Version" firebase.go | awk '{print $4}' | tr -d \"` || true if [[ -z "${RELEASE_VERSION}" ]]; then echo_warn "Failed to extract release version from: firebase.go" terminate fi if [[ ! "${RELEASE_VERSION}" =~ ^([0-9]*)\.([0-9]*)\.([0-9]*)$ ]]; then echo_warn "Malformed release version string: ${RELEASE_VERSION}. Exiting." terminate fi echo_info "Extracted release version: ${RELEASE_VERSION}" echo "version=v${RELEASE_VERSION}" >> $GITHUB_OUTPUT echo_info "" echo_info "--------------------------------------------" echo_info "Checking release tag" echo_info "--------------------------------------------" echo_info "" echo_info "---< git fetch --depth=1 origin +refs/tags/*:refs/tags/* >---" git fetch --depth=1 origin +refs/tags/*:refs/tags/* echo "" readonly EXISTING_TAG=`git rev-parse -q --verify "refs/tags/v${RELEASE_VERSION}"` || true if [[ -n "${EXISTING_TAG}" ]]; then echo_warn "Tag v${RELEASE_VERSION} already exists. Exiting." echo_warn "If the tag was created in a previous unsuccessful attempt, delete it and try again." echo_warn " $ git tag -d v${RELEASE_VERSION}" echo_warn " $ git push --delete origin v${RELEASE_VERSION}" readonly RELEASE_URL="https://github.com/firebase/firebase-admin-go/releases/tag/v${RELEASE_VERSION}" echo_warn "Delete any corresponding releases at ${RELEASE_URL}." terminate fi echo_info "Tag v${RELEASE_VERSION} does not exist." 
echo_info "" echo_info "--------------------------------------------" echo_info "Generating changelog" echo_info "--------------------------------------------" echo_info "" echo_info "---< git fetch origin dev --prune --unshallow >---" git fetch origin dev --prune --unshallow echo "" echo_info "Generating changelog from history..." readonly CURRENT_DIR=$(dirname "$0") readonly CHANGELOG=`${CURRENT_DIR}/generate_changelog.sh` echo "$CHANGELOG" # Parse and preformat the text to handle multi-line output. # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#example-of-a-multiline-string # and https://github.com/github/docs/issues/21529#issue-1418590935 FILTERED_CHANGELOG=`echo "$CHANGELOG" | grep -v "\\[INFO\\]"` FILTERED_CHANGELOG="${FILTERED_CHANGELOG//$'\''/'"'}" echo "changelog<> $GITHUB_OUTPUT echo -e "$FILTERED_CHANGELOG" >> $GITHUB_OUTPUT echo "CHANGELOGEOF" >> $GITHUB_OUTPUT echo "" echo_info "--------------------------------------------" echo_info "PREFLIGHT SUCCESSFUL" echo_info "--------------------------------------------" golang-google-firebase-go-4.18.0/.github/scripts/run_all_tests.sh000077500000000000000000000015371505612111400247540ustar00rootroot00000000000000#!/bin/bash # Copyright 2020 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
set -e set -u gpg --quiet --batch --yes --decrypt --passphrase="${FIREBASE_SERVICE_ACCT_KEY}" \ --output testdata/integration_cert.json .github/resources/integ-service-account.json.gpg echo "${FIREBASE_API_KEY}" > testdata/integration_apikey.txt go test -v -race ./... golang-google-firebase-go-4.18.0/.github/workflows/000077500000000000000000000000001505612111400220775ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/.github/workflows/ci.yml000066400000000000000000000015731505612111400232230ustar00rootroot00000000000000name: Continuous Integration on: pull_request jobs: module: name: Module build runs-on: ubuntu-latest strategy: fail-fast: false matrix: go: ['1.23', '1.24'] steps: - name: Check out code uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Install golint run: go install golang.org/x/lint/golint@latest - name: Run Linter run: | golint -set_exit_status ./... - name: Run Unit Tests if: success() || failure() run: go test -v -race -test.short ./... - name: Run Formatter run: | if [[ ! -z "$(gofmt -l -s .)" ]]; then echo "Go code is not formatted:" gofmt -d -s . exit 1 fi - name: Run Static Analyzer run: go vet -v ./... golang-google-firebase-go-4.18.0/.github/workflows/nightly.yml000066400000000000000000000057461505612111400243140ustar00rootroot00000000000000# Copyright 2021 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
name: Nightly Builds on: # Runs every day at 06:30 AM (PT) and 08:30 PM (PT) / 04:30 AM (UTC) and 02:30 PM (UTC) # or on 'firebase_nightly_build' repository dispatch event. schedule: - cron: "30 4,14 * * *" repository_dispatch: types: [firebase_nightly_build] jobs: nightly: runs-on: ubuntu-latest steps: - name: Check out code uses: actions/checkout@v4 with: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Go uses: actions/setup-go@v5 with: go-version: '1.23' - name: Install golint run: go install golang.org/x/lint/golint@latest - name: Run Linter run: | golint -set_exit_status ./... - name: Run Tests run: ./.github/scripts/run_all_tests.sh env: FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} - name: Send email on failure if: failure() uses: firebase/firebase-admin-node/.github/actions/send-email@master with: api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} from: 'GitHub ' to: ${{ secrets.FIREBASE_ADMIN_GITHUB_EMAIL }} subject: 'Nightly build ${{github.run_id}} of ${{github.repository}} failed!' html: > Nightly workflow ${{github.run_id}} failed on: ${{github.repository}}

Navigate to the failed workflow. continue-on-error: true - name: Send email on cancelled if: cancelled() uses: firebase/firebase-admin-node/.github/actions/send-email@master with: api-key: ${{ secrets.OSS_BOT_MAILGUN_KEY }} domain: ${{ secrets.OSS_BOT_MAILGUN_DOMAIN }} from: 'GitHub ' to: ${{ secrets.FIREBASE_ADMIN_GITHUB_EMAIL }} subject: 'Nightly build ${{github.run_id}} of ${{github.repository}} cancelled!' html: > Nightly workflow ${{github.run_id}} cancelled on: ${{github.repository}}

Navigate to the cancelled workflow. continue-on-error: true golang-google-firebase-go-4.18.0/.github/workflows/release.yml000066400000000000000000000107701505612111400242470ustar00rootroot00000000000000# Copyright 2020 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. name: Release Candidate on: # Only run the workflow when a PR is updated or when a developer explicitly requests # a build by sending a 'firebase_build' event. pull_request: types: [opened, synchronize, closed] repository_dispatch: types: - firebase_build jobs: stage_release: # To publish a release, merge the release PR with the label 'release:publish'. # To stage a release without publishing it, send a 'firebase_build' event or apply # the 'release:stage' label to a PR. if: github.event.action == 'firebase_build' || contains(github.event.pull_request.labels.*.name, 'release:stage') || (github.event.pull_request.merged && contains(github.event.pull_request.labels.*.name, 'release:publish')) runs-on: ubuntu-latest # When manually triggering the build, the requester can specify a target branch or a tag # via the 'ref' client parameter. steps: - name: Check out code uses: actions/checkout@v4 with: ref: ${{ github.event.client_payload.ref || github.ref }} - name: Set up Go uses: actions/setup-go@v5 with: go-version: '1.23' - name: Install golint run: go install golang.org/x/lint/golint@latest - name: Run Linter run: | golint -set_exit_status ./... 
- name: Run Tests run: ./.github/scripts/run_all_tests.sh env: FIREBASE_SERVICE_ACCT_KEY: ${{ secrets.FIREBASE_SERVICE_ACCT_KEY }} FIREBASE_API_KEY: ${{ secrets.FIREBASE_API_KEY }} publish_release: needs: stage_release # Check whether the release should be published. We publish only when the trigger PR is # 1. merged # 2. to the dev branch # 3. with the label 'release:publish', and # 4. the title prefix '[chore] Release '. if: github.event.pull_request.merged && github.ref == 'refs/heads/dev' && contains(github.event.pull_request.labels.*.name, 'release:publish') && startsWith(github.event.pull_request.title, '[chore] Release ') runs-on: ubuntu-latest permissions: contents: write steps: - name: Checkout source for publish uses: actions/checkout@v4 with: persist-credentials: false - name: Publish preflight check id: preflight run: ./.github/scripts/publish_preflight_check.sh # We authorize this step with an access token that has write access to the master branch. - name: Merge to master uses: actions/github-script@v7 with: github-token: ${{ secrets.FIREBASE_GITHUB_TOKEN }} script: | github.rest.repos.merge({ owner: context.repo.owner, repo: context.repo.repo, base: 'master', head: 'dev' }) # See: https://cli.github.com/manual/gh_release_create - name: Create release tag env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: gh release create ${{ steps.preflight.outputs.version }} --title "Firebase Admin Go SDK ${{ steps.preflight.outputs.version }}" --notes '${{ steps.preflight.outputs.changelog }}' --target "master" # Post to Twitter if explicitly opted-in by adding the label 'release:tweet'. - name: Post to Twitter if: success() && contains(github.event.pull_request.labels.*.name, 'release:tweet') uses: firebase/firebase-admin-node/.github/actions/send-tweet@master with: status: > ${{ steps.preflight.outputs.version }} of @Firebase Admin Go SDK is available. 
https://github.com/firebase/firebase-admin-go/releases/tag/${{ steps.preflight.outputs.version }} consumer-key: ${{ secrets.FIREBASE_TWITTER_CONSUMER_KEY }} consumer-secret: ${{ secrets.FIREBASE_TWITTER_CONSUMER_SECRET }} access-token: ${{ secrets.FIREBASE_TWITTER_ACCESS_TOKEN }} access-token-secret: ${{ secrets.FIREBASE_TWITTER_ACCESS_TOKEN_SECRET }} continue-on-error: true golang-google-firebase-go-4.18.0/.gitignore000066400000000000000000000000641505612111400204720ustar00rootroot00000000000000testdata/integration_* .vscode/* *~ \#*\# .DS_Store golang-google-firebase-go-4.18.0/.opensource/000077500000000000000000000000001505612111400207425ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/.opensource/project.json000066400000000000000000000004401505612111400233010ustar00rootroot00000000000000{ "name": "Firebase Admin SDK - Go", "platforms": [ "Go", "Admin" ], "content": "README.md", "pages": [], "related": [ "firebase/firebase-admin-java", "firebase/firebase-admin-node", "firebase/firebase-admin-python" ] } golang-google-firebase-go-4.18.0/CONTRIBUTING.md000066400000000000000000000222651505612111400207420ustar00rootroot00000000000000# Contributing | Firebase Admin Go SDK Thank you for contributing to the Firebase community! - [Have a usage question?](#question) - [Think you found a bug?](#issue) - [Have a feature request?](#feature) - [Want to submit a pull request?](#submit) - [Need to get set up locally?](#local-setup) ## Have a usage question? We get lots of those and we love helping you, but GitHub is not the best place for them. Issues which just ask about usage will be closed. 
Here are some resources to get help: - Go through the [guides](https://firebase.google.com/docs/admin/setup/) - Read the full [API reference](https://godoc.org/firebase.google.com/go) If the official documentation doesn't help, try asking a question on the [Firebase Google Group](https://groups.google.com/forum/#!forum/firebase-talk/) or one of our other [official support channels](https://firebase.google.com/support/). **Please avoid double posting across multiple channels!** ## Think you found a bug? Yeah, we're definitely not perfect! Search through [old issues](https://github.com/firebase/firebase-admin-go/issues) before submitting a new issue as your question may have already been answered. If your issue appears to be a bug, and hasn't been reported, [open a new issue](https://github.com/firebase/firebase-admin-go/issues/new). Please use the provided bug report template and include a minimal repro. If you are up to the challenge, [submit a pull request](#submit) with a fix! ## Have a feature request? Great, we love hearing how we can improve our products! Share you idea through our [feature request support channel](https://firebase.google.com/support/contact/bugs-features/). ## Want to submit a pull request? Sweet, we'd love to accept your contribution! [Open a new pull request](https://github.com/firebase/firebase-admin-go/pull/new/master) and fill out the provided template. Make sure to create all your pull requests against the `dev` branch. All development work takes place on this branch, while the `master` branch is dedicated for released stable code. This enables us to review and merge routine code changes, without impacting downstream applications that are building against our `master` branch. **If you want to implement a new feature, please open an issue with a proposal first so that we can figure out if the feature makes sense and how it will work.** Make sure your changes pass our linter and the tests all pass on your local machine. 
Most non-trivial changes should include some extra test coverage. If you aren't sure how to add tests, feel free to submit regardless and ask us for some advice. Finally, you will need to sign our [Contributor License Agreement](https://cla.developers.google.com/about/google-individual), and go through our code review process before we can accept your pull request. ### Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution. This simply gives us permission to use and redistribute your contributions as part of the project. Head over to to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ### Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Need to get set up locally? ### Initial Setup Use the standard GitHub and [Go development tools](https://golang.org/doc/cmd) to build and test the Firebase Admin SDK. Follow the instructions given in the [golang documentation](https://golang.org/doc/code.html) to get your `GOPATH` set up correctly. Then execute the following series of commands to checkout the sources of Firebase Admin SDK, and its dependencies: ```bash $ cd $GOPATH $ git clone https://github.com/firebase/firebase-admin-go.git src/firebase.google.com/go $ go get -d -t firebase.google.com/go/... # Install dependencies ``` ### Unit Testing Invoke the `go test` command as follows to build and run the unit tests: ```bash go test -test.short firebase.google.com/go/... ``` Note the `-test.short` flag passed to the `go test` command. 
This will skip the integration tests, and only execute the unit tests. ### Integration Testing Integration tests are executed against a real life Firebase project. If you do not already have one suitable for running the tests against, you can create a new project in the [Firebase Console](https://console.firebase.google.com) following the setup guide below. If you already have a Firebase project, you'll need to obtain credentials to communicate and authorize access to your Firebase project: 1. Service account certificate: This allows access to your Firebase project through a service account which is required for all integration tests. This can be downloaded as a JSON file from the **Settings > Service Accounts** tab of the Firebase console when you click the **Generate new private key** button. Copy the file into the repo so it's available at `src/firebase.google.com/go/testdata/integration_cert.json`. > **Note:** Service accounts should be carefully managed and their keys should never be stored in publicly accessible source code or repositories. 2. Web API key: This allows for Auth sign-in needed for some Authentication and Tenant Management integration tests. This is displayed in the **Settings > General** tab of the Firebase console after enabling Authentication as described in the steps below. Copy it and save to a new text file at `src/firebase.google.com/go/testdata/integration_apikey.txt`. Set up your Firebase project as follows: 1. Enable Authentication: 1. Go to the Firebase Console, and select **Authentication** from the **Build** menu. 2. Click on **Get Started**. 3. Select **Sign-in method > Add new provider > Email/Password** then enable both the **Email/Password** and **Email link (passwordless sign-in)** options. 2. Enable Firestore: 1. Go to the Firebase Console, and select **Firestore Database** from the **Build** menu. 2. Click on the **Create database** button. You can choose to set up Firestore either in the production mode or in the test mode. 
3. Enable Realtime Database: 1. Go to the Firebase Console, and select **Realtime Database** from the **Build** menu. 2. Click on the **Create Database** button. You can choose to set up the Realtime Database either in the locked mode or in the test mode. > **Note:** Integration tests are not run against the default Realtime Database reference and are instead run against a database created at `https://{PROJECT_ID}.firebaseio.com`. This second Realtime Database reference is created in the following steps. 3. In the **Data** tab click on the kebab menu (3 dots) and select **Create Database**. 4. Enter your Project ID (Found in the **General** tab in **Account Settings**) as the **Realtime Database reference**. Again, you can choose to set up the Realtime Database either in the locked mode or in the test mode. 4. Enable Storage: 1. Go to the Firebase Console, and select **Storage** from the **Build** menu. 2. Click on the **Get started** button. You can choose to set up Cloud Storage either in the production mode or in the test mode. 5. Enable the IAM API: 1. Go to the [Google Cloud console](https://console.cloud.google.com) and make sure your Firebase project is selected. 2. Select **APIs & Services** from the main menu, and click the **ENABLE APIS AND SERVICES** button. 3. Search for and enable **Identity and Access Management (IAM) API** by Google Enterprise API. 6. Enable Tenant Management: 1. Go to [Google Cloud console | Identity Platform](https://console.cloud.google.com/customer-identity/) and if it is not already enabled, click **Enable**. 2. Then [enable multi-tenancy](https://cloud.google.com/identity-platform/docs/multi-tenancy-quickstart#enabling_multi-tenancy) for your project. 7. Ensure your service account has the **Firebase Authentication Admin** role. This is required to ensure that exported user records contain the password hashes of the user accounts: 1. Go to [Google Cloud console | IAM & admin](https://console.cloud.google.com/iam-admin). 2. 
Find your service account in the list. If not added click the pencil icon to edit its permissions. 3. Click **ADD ANOTHER ROLE** and choose **Firebase Authentication Admin**. 4. Click **SAVE**. Now you can invoke the test suite as follows: ```bash go test firebase.google.com/go/... ``` This will execute both unit and integration test suites. ### Test Coverage Coverage can be measured per package by passing the `-cover` flag to the test invocation: ```bash go test -cover firebase.google.com/go/auth ``` To view the detailed coverage reports (per package): ```bash go test -cover -coverprofile=coverage.out firebase.google.com/go go tool cover -html=coverage.out ``` golang-google-firebase-go-4.18.0/LICENSE000066400000000000000000000261351505612111400175160ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. golang-google-firebase-go-4.18.0/README.md000066400000000000000000000051141505612111400177620ustar00rootroot00000000000000[![Build Status](https://github.com/firebase/firebase-admin-go/workflows/Continuous%20Integration/badge.svg?branch=dev)](https://github.com/firebase/firebase-admin-go/actions) [![GoDoc](https://godoc.org/firebase.google.com/go?status.svg)](https://godoc.org/firebase.google.com/go) [![Go Report Card](https://goreportcard.com/badge/github.com/firebase/firebase-admin-go)](https://goreportcard.com/report/github.com/firebase/firebase-admin-go) # Firebase Admin Go SDK ## Table of Contents * [Overview](#overview) * [Installation](#installation) * [Contributing](#contributing) * [Documentation](#documentation) * [License and Terms](#license-and-terms) ## Overview [Firebase](https://firebase.google.com) provides the tools and infrastructure you need to develop apps, grow your user base, and earn money. The Firebase Admin Go SDK enables access to Firebase services from privileged environments (such as servers or cloud) in Go. Currently this SDK provides Firebase custom authentication support. For more information, visit the [Firebase Admin SDK setup guide](https://firebase.google.com/docs/admin/setup/). 
## Installation The Firebase Admin Go SDK can be installed using the `go get` utility: ``` # Install the latest version: go get firebase.google.com/go/v4@latest # Or install a specific version: go get firebase.google.com/go/v4@4.x.x ``` ## Contributing Please refer to the [CONTRIBUTING page](./CONTRIBUTING.md) for more information about how you can contribute to this project. We welcome bug reports, feature requests, code review feedback, and also pull requests. ## Supported Go Versions The Admin Go SDK is compatible with the two most-recent major Go releases. We currently support Go v1.23 and v1.24. Our [continuous integration](https://github.com/firebase/firebase-admin-go/actions) system tests the code on Go v1.23 and v1.24. ## Documentation * [Setup Guide](https://firebase.google.com/docs/admin/setup/) * [Authentication Guide](https://firebase.google.com/docs/auth/admin/) * [Cloud Firestore](https://firebase.google.com/docs/firestore/) * [Cloud Messaging Guide](https://firebase.google.com/docs/cloud-messaging/admin/) * [Storage Guide](https://firebase.google.com/docs/storage/admin/start) * [API Reference](https://godoc.org/firebase.google.com/go) * [Release Notes](https://firebase.google.com/support/release-notes/admin/go) ## License and Terms Firebase Admin Go SDK is licensed under the [Apache License, version 2.0](http://www.apache.org/licenses/LICENSE-2.0). Your use of Firebase is governed by the [Terms of Service for Firebase Services](https://firebase.google.com/terms/). golang-google-firebase-go-4.18.0/appcheck/000077500000000000000000000000001505612111400202605ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/appcheck/appcheck.go000066400000000000000000000136021505612111400223670ustar00rootroot00000000000000// Copyright 2022 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package appcheck provides functionality for verifying App Check tokens.
package appcheck

import (
	"context"
	"errors"
	"strings"
	"time"

	"github.com/MicahParks/keyfunc"
	"github.com/golang-jwt/jwt/v4"

	"firebase.google.com/go/v4/internal"
)

// JWKSUrl is the URL of the JWKS used to verify App Check tokens.
//
// It is a variable rather than a constant; NewClient reads it at call time.
var JWKSUrl = "https://firebaseappcheck.googleapis.com/v1beta/jwks"

// appCheckIssuer is the prefix that the issuer (iss) claim of a token must start with.
const appCheckIssuer = "https://firebaseappcheck.googleapis.com/"

var (
	// ErrIncorrectAlgorithm is returned when the token is signed with a non-RS256 algorithm.
	ErrIncorrectAlgorithm = errors.New("token has incorrect algorithm")
	// ErrTokenType is returned when the token is not a JWT.
	ErrTokenType = errors.New("token has incorrect type")
	// ErrTokenClaims is returned when the token claims cannot be decoded.
	ErrTokenClaims = errors.New("token has incorrect claims")
	// ErrTokenAudience is returned when the token audience does not match the current project.
	ErrTokenAudience = errors.New("token has incorrect audience")
	// ErrTokenIssuer is returned when the token issuer does not match Firebase's App Check service.
	ErrTokenIssuer = errors.New("token has incorrect issuer")
	// ErrTokenSubject is returned when the token subject is empty or missing.
	ErrTokenSubject = errors.New("token has empty or missing subject")
)

// DecodedAppCheckToken represents a verified App Check token.
//
// DecodedAppCheckToken provides typed accessors to the common JWT fields such as Audience (aud)
// and ExpiresAt (exp). Additionally it provides an AppID field, which indicates the application ID to which this
// token belongs. Any additional JWT claims can be accessed via the Claims map of DecodedAppCheckToken.
type DecodedAppCheckToken struct {
	Issuer    string
	Subject   string
	Audience  []string
	ExpiresAt time.Time
	IssuedAt  time.Time
	// AppID carries the same value as Subject (the sub claim).
	AppID string
	// Claims holds the claims that are not exposed through the typed fields above.
	Claims map[string]interface{}
}

// Client is the interface for the Firebase App Check service.
type Client struct {
	// projectID is matched against the "projects/<id>" entries of the audience claim.
	projectID string
	// jwks supplies the public keys used to verify token signatures.
	jwks *keyfunc.JWKS
}

// NewClient creates a new instance of the Firebase App Check Client.
//
// The key set is fetched from JWKSUrl and configured with a refresh interval of six hours.
// An error from the key set fetch is returned to the caller unchanged.
//
// This function can only be invoked from within the SDK. Client applications should access
// the App Check service through firebase.App.
func NewClient(ctx context.Context, conf *internal.AppCheckConfig) (*Client, error) {
	// TODO: Add support for overriding the HTTP client using the App one.
	jwks, err := keyfunc.Get(JWKSUrl, keyfunc.Options{
		Ctx:             ctx,
		RefreshInterval: 6 * time.Hour,
	})
	if err != nil {
		return nil, err
	}

	return &Client{
		projectID: conf.ProjectID,
		jwks:      jwks,
	}, nil
}

// VerifyToken verifies the given App Check token.
//
// VerifyToken considers an App Check token string to be valid if all the following conditions are met:
//   - The token string is a valid RS256 JWT.
//   - The JWT contains valid issuer (iss) and audience (aud) claims that match the App Check
//     issuer prefix and the project ID of the Client.
//   - The JWT contains a valid subject (sub) claim.
//   - The JWT is not expired, and it has been issued some time in the past.
//   - The JWT is signed by a Firebase App Check backend server as determined by the key set
//     loaded from JWKSUrl.
//
// If any of the above conditions are not met, an error is returned. Otherwise a pointer to a
// decoded App Check token is returned.
func (c *Client) VerifyToken(token string) (*DecodedAppCheckToken, error) { // References for checks: // https://firebase.googleblog.com/2021/10/protecting-backends-with-app-check.html // https://github.com/firebase/firebase-admin-node/blob/master/src/app-check/token-verifier.ts#L106 // The standard JWT parser also validates the expiration of the token // so we do not need dedicated code for that. decodedToken, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { if t.Header["alg"] != "RS256" { return nil, ErrIncorrectAlgorithm } if t.Header["typ"] != "JWT" { return nil, ErrTokenType } return c.jwks.Keyfunc(t) }) if err != nil { return nil, err } claims, ok := decodedToken.Claims.(jwt.MapClaims) if !ok { return nil, ErrTokenClaims } rawAud := claims["aud"].([]interface{}) aud := []string{} for _, v := range rawAud { aud = append(aud, v.(string)) } if !contains(aud, "projects/"+c.projectID) { return nil, ErrTokenAudience } // We check the prefix to make sure this token was issued // by the Firebase App Check service, but we do not check the // Project Number suffix because the Golang SDK only has project ID. // // This is consistent with the Firebase Admin Node SDK. if !strings.HasPrefix(claims["iss"].(string), appCheckIssuer) { return nil, ErrTokenIssuer } if val, ok := claims["sub"].(string); !ok || val == "" { return nil, ErrTokenSubject } appCheckToken := DecodedAppCheckToken{ Issuer: claims["iss"].(string), Subject: claims["sub"].(string), Audience: aud, ExpiresAt: time.Unix(int64(claims["exp"].(float64)), 0), IssuedAt: time.Unix(int64(claims["iat"].(float64)), 0), AppID: claims["sub"].(string), } // Remove all the claims we've already parsed. 
for _, usedClaim := range []string{"iss", "sub", "aud", "exp", "iat", "sub"} { delete(claims, usedClaim) } appCheckToken.Claims = claims return &appCheckToken, nil } func contains(s []string, str string) bool { for _, v := range s { if v == str { return true } } return false } golang-google-firebase-go-4.18.0/appcheck/appcheck_test.go000066400000000000000000000170251505612111400234310ustar00rootroot00000000000000package appcheck import ( "context" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "net/http" "net/http/httptest" "os" "testing" "time" "firebase.google.com/go/v4/internal" "github.com/golang-jwt/jwt/v4" "github.com/google/go-cmp/cmp" ) func TestVerifyTokenHasValidClaims(t *testing.T) { ts, err := setupFakeJWKS() if err != nil { t.Fatalf("Error setting up fake JWKS server: %v", err) } defer ts.Close() privateKey, err := loadPrivateKey() if err != nil { t.Fatalf("Error loading private key: %v", err) } JWKSUrl = ts.URL conf := &internal.AppCheckConfig{ ProjectID: "project_id", } client, err := NewClient(context.Background(), conf) if err != nil { t.Errorf("Error creating NewClient: %v", err) } type appCheckClaims struct { Aud []string `json:"aud"` jwt.RegisteredClaims } mockTime := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC) jwt.TimeFunc = func() time.Time { return mockTime } tokenTests := []struct { claims *appCheckClaims wantErr error wantToken *DecodedAppCheckToken }{ { &appCheckClaims{ []string{"projects/12345678", "projects/project_id"}, jwt.RegisteredClaims{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", ExpiresAt: jwt.NewNumericDate(mockTime.Add(time.Hour)), IssuedAt: jwt.NewNumericDate(mockTime), }}, nil, &DecodedAppCheckToken{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", Audience: []string{"projects/12345678", "projects/project_id"}, ExpiresAt: mockTime.Add(time.Hour), IssuedAt: mockTime, AppID: "12345678:app:ID", Claims: map[string]interface{}{}, }, }, 
{ &appCheckClaims{ []string{"projects/12345678", "projects/project_id"}, jwt.RegisteredClaims{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", ExpiresAt: jwt.NewNumericDate(mockTime.Add(time.Hour)), IssuedAt: jwt.NewNumericDate(mockTime), // A field our AppCheckToken does not use. NotBefore: jwt.NewNumericDate(mockTime.Add(-1 * time.Hour)), }}, nil, &DecodedAppCheckToken{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", Audience: []string{"projects/12345678", "projects/project_id"}, ExpiresAt: mockTime.Add(time.Hour), IssuedAt: mockTime, AppID: "12345678:app:ID", Claims: map[string]interface{}{ "nbf": float64(mockTime.Add(-1 * time.Hour).Unix()), }, }, }, { &appCheckClaims{ []string{"projects/0000000", "projects/another_project_id"}, jwt.RegisteredClaims{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", ExpiresAt: jwt.NewNumericDate(mockTime.Add(time.Hour)), IssuedAt: jwt.NewNumericDate(mockTime), }}, ErrTokenAudience, nil, }, { &appCheckClaims{ []string{"projects/12345678", "projects/project_id"}, jwt.RegisteredClaims{ Issuer: "https://not-firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", ExpiresAt: jwt.NewNumericDate(mockTime.Add(time.Hour)), IssuedAt: jwt.NewNumericDate(mockTime), }}, ErrTokenIssuer, nil, }, { &appCheckClaims{ []string{"projects/12345678", "projects/project_id"}, jwt.RegisteredClaims{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "", ExpiresAt: jwt.NewNumericDate(mockTime.Add(time.Hour)), IssuedAt: jwt.NewNumericDate(mockTime), }}, ErrTokenSubject, nil, }, { &appCheckClaims{ []string{"projects/12345678", "projects/project_id"}, jwt.RegisteredClaims{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", ExpiresAt: jwt.NewNumericDate(mockTime.Add(time.Hour)), IssuedAt: jwt.NewNumericDate(mockTime), }}, ErrTokenSubject, nil, }, } for _, tc := range tokenTests { // Create an App 
Check-style token. jwtToken := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.claims) // kid matches the key ID in testdata/mock.jwks.json, // which is the public key matching to the private key // in testdata/appcheck_pk.pem. jwtToken.Header["kid"] = "FGQdnRlzAmKyKr6-Hg_kMQrBkj_H6i6ADnBQz4OI6BU" token, err := jwtToken.SignedString(privateKey) if err != nil { t.Fatalf("error generating JWT: %v", err) } // Verify the token. gotToken, gotErr := client.VerifyToken(token) if !errors.Is(gotErr, tc.wantErr) { t.Errorf("Expected error %v, got %v", tc.wantErr, gotErr) continue } if diff := cmp.Diff(tc.wantToken, gotToken); diff != "" { t.Errorf("VerifyToken mismatch (-want +got):\n%s", diff) } } } func TestVerifyTokenMustExist(t *testing.T) { ts, err := setupFakeJWKS() if err != nil { t.Fatalf("Error setting up fake JWK server: %v", err) } defer ts.Close() JWKSUrl = ts.URL conf := &internal.AppCheckConfig{ ProjectID: "project_id", } client, err := NewClient(context.Background(), conf) if err != nil { t.Errorf("Error creating NewClient: %v", err) } for _, token := range []string{"", "-", "."} { gotToken, gotErr := client.VerifyToken(token) if gotErr == nil { t.Errorf("VerifyToken(%s) expected error, got nil", token) } if gotToken != nil { t.Errorf("Expected nil, got token %v", gotToken) } } } func TestVerifyTokenNotExpired(t *testing.T) { ts, err := setupFakeJWKS() if err != nil { t.Fatalf("Error setting up fake JWKS server: %v", err) } defer ts.Close() privateKey, err := loadPrivateKey() if err != nil { t.Fatalf("Error loading private key: %v", err) } JWKSUrl = ts.URL conf := &internal.AppCheckConfig{ ProjectID: "project_id", } client, err := NewClient(context.Background(), conf) if err != nil { t.Errorf("Error creating NewClient: %v", err) } mockTime := time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC) jwt.TimeFunc = func() time.Time { return mockTime } tokenTests := []struct { expiresAt time.Time wantErr bool }{ // Expire in the future is OK. 
{mockTime.Add(time.Hour), false}, // Expire in the past is not OK. {mockTime.Add(-1 * time.Hour), true}, } for _, tc := range tokenTests { claims := struct { Aud []string `json:"aud"` jwt.RegisteredClaims }{ []string{"projects/12345678", "projects/project_id"}, jwt.RegisteredClaims{ Issuer: "https://firebaseappcheck.googleapis.com/12345678", Subject: "12345678:app:ID", ExpiresAt: jwt.NewNumericDate(tc.expiresAt), IssuedAt: jwt.NewNumericDate(mockTime), }, } jwtToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) jwtToken.Header["kid"] = "FGQdnRlzAmKyKr6-Hg_kMQrBkj_H6i6ADnBQz4OI6BU" token, err := jwtToken.SignedString(privateKey) if err != nil { t.Fatalf("error generating JWT: %v", err) } _, gotErr := client.VerifyToken(token) if tc.wantErr && gotErr == nil { t.Errorf("Expected an error, got none") } else if !tc.wantErr && gotErr != nil { t.Errorf("Expected no error, got %v", gotErr) } } }

// setupFakeJWKS starts an HTTP test server that answers every request with
// the contents of ../testdata/mock.jwks.json. The caller must Close the
// returned server.
func setupFakeJWKS() (*httptest.Server, error) {
	jwks, err := os.ReadFile("../testdata/mock.jwks.json")
	if err != nil {
		return nil, err
	}
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write(jwks)
	}))
	return ts, nil
}

// loadPrivateKey reads ../testdata/appcheck_pk.pem and parses its first PEM
// block as a PKCS #1 RSA private key.
//
// An error is returned when the file cannot be read, holds no PEM block, or
// the block is not a valid PKCS #1 key.
func loadPrivateKey() (*rsa.PrivateKey, error) {
	pk, err := os.ReadFile("../testdata/appcheck_pk.pem")
	if err != nil {
		return nil, err
	}
	// pem.Decode returns a nil block when the input holds no PEM data;
	// report that instead of dereferencing nil.
	block, _ := pem.Decode(pk)
	if block == nil {
		return nil, errors.New("no PEM block found in ../testdata/appcheck_pk.pem")
	}
	return x509.ParsePKCS1PrivateKey(block.Bytes)
}
golang-google-firebase-go-4.18.0/auth/000077500000000000000000000000001505612111400174435ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/auth/auth.go000066400000000000000000000410461505612111400207400ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package auth contains functions for minting custom authentication tokens, verifying Firebase ID tokens,
// and managing users in a Firebase project.
package auth

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"
	"time"

	"firebase.google.com/go/v4/internal"
	"golang.org/x/oauth2"
	"google.golang.org/api/option"
	"google.golang.org/api/transport"
)

const (
	// authErrorCode is the key under which an SDK-generated error code is
	// stored in the Ext map of an internal.FirebaseError.
	authErrorCode = "authErrorCode"
	// emulatorHostEnvVar names the environment variable that, when non-empty,
	// switches NewClient into emulator mode.
	emulatorHostEnvVar = "FIREBASE_AUTH_EMULATOR_HOST"
	// defaultAuthURL is the base URL used when not running against the emulator.
	defaultAuthURL = "https://identitytoolkit.googleapis.com"
	// firebaseAudience is the audience (aud) written into custom tokens.
	firebaseAudience = "https://identitytoolkit.googleapis.com/google.identity.identitytoolkit.v1.IdentityToolkit"
	// oneHourInSeconds is the lifetime of a custom token (exp = iat + this value).
	oneHourInSeconds = 3600

	// SDK-generated error codes
	idTokenRevoked       = "ID_TOKEN_REVOKED"
	userDisabled         = "USER_DISABLED"
	sessionCookieRevoked = "SESSION_COOKIE_REVOKED"
	tenantIDMismatch     = "TENANT_ID_MISMATCH"
)

// reservedClaims lists the claim names that CustomTokenWithClaims refuses to
// accept as developer claims.
var reservedClaims = []string{
	"acr", "amr", "at_hash", "aud", "auth_time", "azp", "cnf", "c_hash", "exp",
	"firebase", "iat", "iss", "jti", "nbf", "nonce", "sub",
}

// emulatorToken is the static OAuth2 token NewClient installs as the token
// source when running in emulator mode.
var emulatorToken = &oauth2.Token{
	AccessToken: "owner",
}

// Client is the interface for the Firebase auth service.
//
// Client facilitates generating custom JWT tokens for Firebase clients, and verifying ID tokens issued
// by Firebase backend services.
type Client struct {
	*baseClient
	TenantManager *TenantManager
}

// NewClient creates a new instance of the Firebase Auth Client.
//
// This function can only be invoked from within the SDK. Client applications should access the
// Auth service through firebase.App.
func NewClient(ctx context.Context, conf *internal.AuthConfig) (*Client, error) { var ( isEmulator bool signer cryptoSigner err error ) authEmulatorHost := os.Getenv(emulatorHostEnvVar) if authEmulatorHost != "" { isEmulator = true signer = emulatedSigner{} } if signer == nil { creds, _ := transport.Creds(ctx, conf.Opts...) // Initialize a signer by following the go/firebase-admin-sign protocol. if creds != nil && len(creds.JSON) > 0 { // If the SDK was initialized with a service account, use it to sign bytes. signer, err = signerFromCreds(creds.JSON) if err != nil && err != errNotAServiceAcct { return nil, err } } } if signer == nil { if conf.ServiceAccountID != "" { // If the SDK was initialized with a service account email, use it with the IAM service // to sign bytes. signer, err = newIAMSigner(ctx, conf) if err != nil { return nil, err } } else { // Use GAE signing capabilities if available. Otherwise, obtain a service account email // from the local Metadata service, and fallback to the IAM service. signer, err = newCryptoSigner(ctx, conf) if err != nil { return nil, err } } } idTokenVerifier, err := newIDTokenVerifier(ctx, conf.ProjectID) if err != nil { return nil, err } cookieVerifier, err := newSessionCookieVerifier(ctx, conf.ProjectID) if err != nil { return nil, err } var opts []option.ClientOption if isEmulator { ts := oauth2.StaticTokenSource(emulatorToken) opts = append(opts, option.WithTokenSource(ts)) } else { opts = append(opts, conf.Opts...) } transport, _, err := transport.NewHTTPClient(ctx, opts...) 
if err != nil { return nil, err } hc := internal.WithDefaultRetryConfig(transport) hc.CreateErrFn = handleHTTPError hc.Opts = []internal.HTTPOption{ internal.WithHeader("X-Client-Version", fmt.Sprintf("Go/Admin/%s", conf.Version)), internal.WithHeader("x-goog-api-client", internal.GetMetricsHeader(conf.Version)), } baseURL := defaultAuthURL if isEmulator { baseURL = fmt.Sprintf("http://%s/identitytoolkit.googleapis.com", authEmulatorHost) } idToolkitV1Endpoint := fmt.Sprintf("%s/v1", baseURL) idToolkitV2Endpoint := fmt.Sprintf("%s/v2", baseURL) userManagementEndpoint := idToolkitV1Endpoint providerConfigEndpoint := idToolkitV2Endpoint tenantMgtEndpoint := idToolkitV2Endpoint projectMgtEndpoint := idToolkitV2Endpoint base := &baseClient{ userManagementEndpoint: userManagementEndpoint, providerConfigEndpoint: providerConfigEndpoint, tenantMgtEndpoint: tenantMgtEndpoint, projectMgtEndpoint: projectMgtEndpoint, projectID: conf.ProjectID, httpClient: hc, idTokenVerifier: idTokenVerifier, cookieVerifier: cookieVerifier, signer: signer, clock: internal.SystemClock, isEmulator: isEmulator, } return &Client{ baseClient: base, TenantManager: newTenantManager(hc, conf, base), }, nil } // CustomToken creates a signed custom authentication token with the specified user ID. // // The resulting JWT can be used in a Firebase client SDK to trigger an authentication flow. See // https://firebase.google.com/docs/auth/admin/create-custom-tokens#sign_in_using_custom_tokens_on_clients // for more details on how to use custom tokens for client authentication. // // CustomToken follows the protocol outlined below to sign the generated tokens: // - If the SDK was initialized with service account credentials, uses the private key present in // the credentials to sign tokens locally. // - If a service account email was specified during initialization (via firebase.Config struct), // calls the IAMCredentials service with that email to sign tokens remotely. 
See // https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob. // - If the code is deployed in the Google App Engine standard environment, uses the App Identity // service to sign tokens. See https://cloud.google.com/appengine/docs/standard/go/reference#SignBytes. // - If the code is deployed in a different GCP-managed environment (e.g. Google Compute Engine), // uses the local Metadata server to auto discover a service account email. This is used in // conjunction with the IAM service to sign tokens remotely. // // CustomToken returns an error the SDK fails to discover a viable mechanism for signing tokens. func (c *baseClient) CustomToken(ctx context.Context, uid string) (string, error) { return c.CustomTokenWithClaims(ctx, uid, nil) } // CustomTokenWithClaims is similar to CustomToken, but in addition to the user ID, it also encodes // all the key-value pairs in the provided map as claims in the resulting JWT. func (c *baseClient) CustomTokenWithClaims(ctx context.Context, uid string, devClaims map[string]interface{}) (string, error) { iss, err := c.signer.Email(ctx) if err != nil { return "", err } if len(uid) == 0 || len(uid) > 128 { return "", errors.New("uid must be non-empty, and not longer than 128 characters") } var disallowed []string for _, k := range reservedClaims { if _, contains := devClaims[k]; contains { disallowed = append(disallowed, k) } } if len(disallowed) == 1 { return "", fmt.Errorf("developer claim %q is reserved and cannot be specified", disallowed[0]) } else if len(disallowed) > 1 { return "", fmt.Errorf("developer claims %q are reserved and cannot be specified", strings.Join(disallowed, ", ")) } now := c.clock.Now().Unix() info := &jwtInfo{ header: jwtHeader{Algorithm: c.signer.Algorithm(), Type: "JWT"}, payload: &customToken{ Iss: iss, Sub: iss, Aud: firebaseAudience, UID: uid, Iat: now, Exp: now + oneHourInSeconds, TenantID: c.tenantID, Claims: devClaims, }, } return info.Token(ctx, c.signer) 
} // SessionCookie creates a new Firebase session cookie from the given ID token and expiry // duration. The returned JWT can be set as a server-side session cookie with a custom cookie // policy. Expiry duration must be at least 5 minutes but may not exceed 14 days. func (c *Client) SessionCookie( ctx context.Context, idToken string, expiresIn time.Duration, ) (string, error) { return c.baseClient.createSessionCookie(ctx, idToken, expiresIn) } // Token represents a decoded Firebase ID token. // // Token provides typed accessors to the common JWT fields such as Audience (aud) and Expiry (exp). // Additionally it provides a UID field, which indicates the user ID of the account to which this token // belongs. Any additional JWT claims can be accessed via the Claims map of Token. type Token struct { AuthTime int64 `json:"auth_time"` Issuer string `json:"iss"` Audience string `json:"aud"` Expires int64 `json:"exp"` IssuedAt int64 `json:"iat"` Subject string `json:"sub,omitempty"` UID string `json:"uid,omitempty"` Firebase FirebaseInfo `json:"firebase"` Claims map[string]interface{} `json:"-"` } // FirebaseInfo represents the information about the sign-in event, including which auth provider // was used and provider-specific identity details. // // This data is provided by the Firebase Auth service and is a reserved claim in the ID token. type FirebaseInfo struct { SignInProvider string `json:"sign_in_provider"` Tenant string `json:"tenant"` Identities map[string]interface{} `json:"identities"` } // baseClient exposes the APIs common to both auth.Client and auth.TenantClient. 
type baseClient struct { userManagementEndpoint string providerConfigEndpoint string tenantMgtEndpoint string projectMgtEndpoint string projectID string tenantID string httpClient *internal.HTTPClient idTokenVerifier *tokenVerifier cookieVerifier *tokenVerifier signer cryptoSigner clock internal.Clock isEmulator bool } func (c *baseClient) withTenantID(tenantID string) *baseClient { copy := *c copy.tenantID = tenantID return © } // VerifyIDToken verifies the signature and payload of the provided ID token. // // VerifyIDToken accepts a signed JWT token string, and verifies that it is current, issued for the // correct Firebase project, and signed by the Google Firebase services in the cloud. It returns // a Token containing the decoded claims in the input JWT. See // https://firebase.google.com/docs/auth/admin/verify-id-tokens#retrieve_id_tokens_on_clients for // more details on how to obtain an ID token in a client app. // // In non-emulator mode, this function does not make any RPC calls most of the time. // The only time it makes an RPC call is when Google public keys need to be refreshed. // These keys get cached up to 24 hours, and therefore the RPC overhead gets amortized // over many invocations of this function. // // This does not check whether or not the token has been revoked or disabled. Use `VerifyIDTokenAndCheckRevoked()` // when a revocation check is needed. func (c *baseClient) VerifyIDToken(ctx context.Context, idToken string) (*Token, error) { return c.verifyIDToken(ctx, idToken, false) } // VerifyIDTokenAndCheckRevoked verifies the provided ID token, and additionally checks that the // token has not been revoked or disabled. // // Unlike `VerifyIDToken()`, this function must make an RPC call to perform the revocation check. // Developers are advised to take this additional overhead into consideration when including this // function in an authorization flow that gets executed often. 
func (c *baseClient) VerifyIDTokenAndCheckRevoked(ctx context.Context, idToken string) (*Token, error) { return c.verifyIDToken(ctx, idToken, true) } func (c *baseClient) verifyIDToken(ctx context.Context, idToken string, checkRevokedOrDisabled bool) (*Token, error) { decoded, err := c.idTokenVerifier.VerifyToken(ctx, idToken, c.isEmulator) if err != nil { return nil, err } if c.tenantID != "" && c.tenantID != decoded.Firebase.Tenant { return nil, &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: fmt.Sprintf("invalid tenant id: %q", decoded.Firebase.Tenant), Ext: map[string]interface{}{ authErrorCode: tenantIDMismatch, }, } } if c.isEmulator || checkRevokedOrDisabled { err = c.checkRevokedOrDisabled(ctx, decoded, idTokenRevoked, "ID token has been revoked") if err != nil { return nil, err } } return decoded, nil } // IsTenantIDMismatch checks if the given error was due to a mismatched tenant ID in a JWT. func IsTenantIDMismatch(err error) bool { return hasAuthErrorCode(err, tenantIDMismatch) } // IsIDTokenRevoked checks if the given error was due to a revoked ID token. // // When IsIDTokenRevoked returns true, IsIDTokenInvalid is guaranteed to return true. func IsIDTokenRevoked(err error) bool { return hasAuthErrorCode(err, idTokenRevoked) } // IsUserDisabled checks if the given error was due to a disabled ID token // // When IsUserDisabled returns true, IsIDTokenInvalid is guaranteed to return true. func IsUserDisabled(err error) bool { return hasAuthErrorCode(err, userDisabled) } // VerifySessionCookie verifies the signature and payload of the provided Firebase session cookie. // // VerifySessionCookie accepts a signed JWT token string, and verifies that it is current, issued for the // correct Firebase project, and signed by the Google Firebase services in the cloud. It returns a Token containing the // decoded claims in the input JWT. 
See https://firebase.google.com/docs/auth/admin/manage-cookies for more details on // how to obtain a session cookie. // // In non-emulator mode, this function does not make any RPC calls most of the time. // The only time it makes an RPC call is when Google public keys need to be refreshed. // These keys get cached up to 24 hours, and therefore the RPC overhead gets amortized // over many invocations of this function. // // This does not check whether or not the cookie has been revoked. Use `VerifySessionCookieAndCheckRevoked()` // when a revocation check is needed. func (c *Client) VerifySessionCookie(ctx context.Context, sessionCookie string) (*Token, error) { return c.verifySessionCookie(ctx, sessionCookie, false) } // VerifySessionCookieAndCheckRevoked verifies the provided session cookie, and additionally checks that the // cookie has not been revoked and the user has not been disabled. // // Unlike `VerifySessionCookie()`, this function must make an RPC call to perform the revocation check. // Developers are advised to take this additional overhead into consideration when including this // function in an authorization flow that gets executed often. func (c *Client) VerifySessionCookieAndCheckRevoked(ctx context.Context, sessionCookie string) (*Token, error) { return c.verifySessionCookie(ctx, sessionCookie, true) } func (c *Client) verifySessionCookie(ctx context.Context, sessionCookie string, checkRevokedOrDisabled bool) (*Token, error) { decoded, err := c.cookieVerifier.VerifyToken(ctx, sessionCookie, c.isEmulator) if err != nil { return nil, err } if c.isEmulator || checkRevokedOrDisabled { err := c.checkRevokedOrDisabled(ctx, decoded, sessionCookieRevoked, "session cookie has been revoked") if err != nil { return nil, err } } return decoded, nil } // IsSessionCookieRevoked checks if the given error was due to a revoked session cookie. // // When IsSessionCookieRevoked returns true, IsSessionCookieInvalid is guaranteed to return true. 
func IsSessionCookieRevoked(err error) bool { return hasAuthErrorCode(err, sessionCookieRevoked) } // checkRevokedOrDisabled checks whether the input token has been revoked or disabled. func (c *baseClient) checkRevokedOrDisabled(ctx context.Context, token *Token, errCode string, errMessage string) error { user, err := c.GetUser(ctx, token.UID) if err != nil { return err } if user.Disabled { return &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: "user has been disabled", Ext: map[string]interface{}{ authErrorCode: userDisabled, }, } } if token.IssuedAt*1000 < user.TokensValidAfterMillis { return &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: errMessage, Ext: map[string]interface{}{ authErrorCode: errCode, }, } } return nil } func hasAuthErrorCode(err error, code string) bool { fe, ok := err.(*internal.FirebaseError) if !ok { return false } got, ok := fe.Ext[authErrorCode] return ok && got == code } golang-google-firebase-go-4.18.0/auth/auth_appengine.go000066400000000000000000000021721505612111400227630ustar00rootroot00000000000000//go:build appengine // +build appengine // Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth

import (
	"context"

	"firebase.google.com/go/v4/internal"
	"google.golang.org/appengine/v2"
)

// aeSigner is the cryptoSigner used when building with the appengine build tag. It holds no
// state; both of its methods delegate to the appengine package.
type aeSigner struct{}

// newCryptoSigner returns an aeSigner. Neither ctx nor conf is consulted, and the returned
// error is always nil.
func newCryptoSigner(ctx context.Context, conf *internal.AuthConfig) (cryptoSigner, error) {
	return aeSigner{}, nil
}

// Email returns whatever appengine.ServiceAccount reports for ctx.
func (s aeSigner) Email(ctx context.Context) (string, error) {
	return appengine.ServiceAccount(ctx)
}

// Sign signs b via appengine.SignBytes and returns the signature. The first value returned
// by SignBytes is discarded (presumably the signing key name; confirm against the appengine
// documentation).
func (s aeSigner) Sign(ctx context.Context, b []byte) ([]byte, error) {
	_, sig, err := appengine.SignBytes(ctx, b)
	return sig, err
}
golang-google-firebase-go-4.18.0/auth/auth_std.go000066400000000000000000000015321505612111400216060ustar00rootroot00000000000000//go:build !appengine
// +build !appengine

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
	"context"

	"firebase.google.com/go/v4/internal"
)

// newCryptoSigner is the variant compiled without the appengine build tag; it delegates to
// newIAMSigner with the same arguments.
func newCryptoSigner(ctx context.Context, conf *internal.AuthConfig) (cryptoSigner, error) {
	return newIAMSigner(ctx, conf)
}
golang-google-firebase-go-4.18.0/auth/auth_test.go000066400000000000000000001300621505612111400217740ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"strings"
	"syscall"
	"testing"
	"time"

	"firebase.google.com/go/v4/errorutils"
	"firebase.google.com/go/v4/internal"
	"golang.org/x/oauth2/google"
	"google.golang.org/api/option"
	"google.golang.org/api/transport"
)

// Constants shared by the tests in this file.
const (
	credEnvVar                 = "GOOGLE_APPLICATION_CREDENTIALS"
	testProjectID              = "mock-project-id"
	testVersion                = "test-version"
	defaultIDToolkitV1Endpoint = "https://identitytoolkit.googleapis.com/v1"
	defaultIDToolkitV2Endpoint = "https://identitytoolkit.googleapis.com/v2"
)

// Shared fixtures. The first group is populated by TestMain before any test runs; the
// option slices and testClock are initialized here.
var (
	testGetUserResponse         []byte
	testGetDisabledUserResponse []byte
	testIDToken                 string
	testSessionCookie           string
	testSigner                  cryptoSigner
	testIDTokenVerifier         *tokenVerifier
	testCookieVerifier          *tokenVerifier

	optsWithServiceAcct = []option.ClientOption{
		option.WithCredentialsFile("../testdata/service_account.json"),
	}
	optsWithTokenSource = []option.ClientOption{
		option.WithTokenSource(&internal.MockTokenSource{
			AccessToken: "test.token",
		}),
	}
	testClock = &internal.MockClock{Timestamp: time.Now()}
)

// TestMain builds the shared signer, verifiers, canned user responses and default
// token/cookie fixtures, passing every error to logFatal, and then runs the tests.
func TestMain(m *testing.M) {
	var err error
	testSigner, err = signerForTests(context.Background())
	logFatal(err)

	testIDTokenVerifier, err = idTokenVerifierForTests(context.Background())
	logFatal(err)

	testCookieVerifier, err = cookieVerifierForTests(context.Background())
	logFatal(err)

	testGetUserResponse, err = ioutil.ReadFile("../testdata/get_user.json")
	logFatal(err)

	testGetDisabledUserResponse, err = ioutil.ReadFile("../testdata/get_disabled_user.json")
	logFatal(err)

	testIDToken = getIDToken(nil)
testSessionCookie = getSessionCookie(nil)
	os.Exit(m.Run())
}

// TestNewClientWithServiceAccountCredentials checks that a client built from the service
// account credentials file gets a serviceAccountSigner, verifiers and base client wired to
// the credentials' project ID, and the system clock.
func TestNewClientWithServiceAccountCredentials(t *testing.T) {
	creds, err := transport.Creds(context.Background(), optsWithServiceAcct...)
	if err != nil {
		t.Fatal(err)
	}
	client, err := NewClient(context.Background(), &internal.AuthConfig{
		Opts:      optsWithServiceAcct,
		ProjectID: creds.ProjectID,
		Version:   testVersion,
	})
	if err != nil {
		t.Fatal(err)
	}

	if _, ok := client.signer.(*serviceAccountSigner); !ok {
		t.Errorf("NewClient().signer = %#v; want = serviceAccountSigner", client.signer)
	}
	if err := checkIDTokenVerifier(client.idTokenVerifier, creds.ProjectID); err != nil {
		t.Errorf("NewClient().idTokenVerifier: %v", err)
	}
	if err := checkCookieVerifier(client.cookieVerifier, creds.ProjectID); err != nil {
		t.Errorf("NewClient().cookieVerifier: %v", err)
	}
	if err := checkBaseClient(client, creds.ProjectID); err != nil {
		t.Errorf("NewClient().baseClient: %v", err)
	}
	if client.clock != internal.SystemClock {
		t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
	}
}

// TestNewClientWithoutCredentials checks that a client built from a bare token source (no
// project ID) falls back to an iamSigner.
func TestNewClientWithoutCredentials(t *testing.T) {
	conf := &internal.AuthConfig{
		Opts:    optsWithTokenSource,
		Version: testVersion,
	}
	client, err := NewClient(context.Background(), conf)
	if err != nil {
		t.Fatal(err)
	}

	if _, ok := client.signer.(*iamSigner); !ok {
		t.Errorf("NewClient().signer = %#v; want = iamSigner", client.signer)
	}
	if err := checkIDTokenVerifier(client.idTokenVerifier, ""); err != nil {
		t.Errorf("NewClient().idTokenVerifier = %v; want = nil", err)
	}
	if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
		t.Errorf("NewClient().cookieVerifier: %v", err)
	}
	if err := checkBaseClient(client, ""); err != nil {
		t.Errorf("NewClient().baseClient: %v", err)
	}
	if client.clock != internal.SystemClock {
		t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
	}
}

// TestNewClientWithServiceAccountID checks that an explicitly configured ServiceAccountID
// is what the resulting iamSigner reports from Email.
func TestNewClientWithServiceAccountID(t *testing.T) {
	conf := &internal.AuthConfig{
		Opts:             optsWithTokenSource,
		ServiceAccountID: "explicit-service-account",
		Version:          testVersion,
	}
	client, err := NewClient(context.Background(), conf)
	if err != nil {
		t.Fatal(err)
	}

	if _, ok := client.signer.(*iamSigner); !ok {
		t.Errorf("NewClient().signer = %#v; want = iamSigner", client.signer)
	}
	if err := checkIDTokenVerifier(client.idTokenVerifier, ""); err != nil {
		t.Errorf("NewClient().idTokenVerifier = %v; want = nil", err)
	}
	if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
		t.Errorf("NewClient().cookieVerifier: %v", err)
	}
	if err := checkBaseClient(client, ""); err != nil {
		t.Errorf("NewClient().baseClient: %v", err)
	}
	if client.clock != internal.SystemClock {
		t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
	}

	email, err := client.signer.Email(context.Background())
	if email != conf.ServiceAccountID || err != nil {
		t.Errorf("Email() = (%q, %v); want = (%q, nil)", email, err, conf.ServiceAccountID)
	}
}

// TestNewClientWithUserCredentials checks that credentials JSON holding only a client ID and
// secret (no private key) yields an iamSigner.
func TestNewClientWithUserCredentials(t *testing.T) {
	creds := &google.DefaultCredentials{
		JSON: []byte(`{ "client_id": "test-client", "client_secret": "test-secret" }`),
	}
	conf := &internal.AuthConfig{
		Opts:    []option.ClientOption{option.WithCredentials(creds)},
		Version: testVersion,
	}
	client, err := NewClient(context.Background(), conf)
	if err != nil {
		t.Fatal(err)
	}

	if _, ok := client.signer.(*iamSigner); !ok {
		t.Errorf("NewClient().signer = %#v; want = iamSigner", client.signer)
	}
	if err := checkIDTokenVerifier(client.idTokenVerifier, ""); err != nil {
		t.Errorf("NewClient().idTokenVerifier = %v; want = nil", err)
	}
	if err := checkCookieVerifier(client.cookieVerifier, ""); err != nil {
		t.Errorf("NewClient().cookieVerifier: %v", err)
	}
	if err := checkBaseClient(client, ""); err != nil {
		t.Errorf("NewClient().baseClient: %v", err)
	}
	if client.clock != internal.SystemClock {
		t.Errorf("NewClient().clock = %v; want = SystemClock", client.clock)
	}
}

// TestNewClientWithMalformedCredentials checks that non-JSON credentials make NewClient fail.
func TestNewClientWithMalformedCredentials(t *testing.T) {
	creds := &google.DefaultCredentials{
		JSON: []byte("not json"),
}
	conf := &internal.AuthConfig{
		Opts: []option.ClientOption{
			option.WithCredentials(creds),
		},
	}
	if c, err := NewClient(context.Background(), conf); c != nil || err == nil {
		t.Errorf("NewClient() = (%v,%v); want = (nil, error)", c, err)
	}
}

// TestNewClientWithInvalidPrivateKey checks that service account JSON with an unparsable
// private key makes NewClient fail.
func TestNewClientWithInvalidPrivateKey(t *testing.T) {
	sa := map[string]interface{}{
		"private_key":  "not-a-private-key",
		"client_email": "foo@bar",
	}
	b, err := json.Marshal(sa)
	if err != nil {
		t.Fatal(err)
	}
	creds := &google.DefaultCredentials{JSON: b}
	conf := &internal.AuthConfig{
		Opts: []option.ClientOption{
			option.WithCredentials(creds),
		},
	}
	if c, err := NewClient(context.Background(), conf); c != nil || err == nil {
		t.Errorf("NewClient() = (%v,%v); want = (nil, error)", c, err)
	}
}

// TestNewClientAppDefaultCredentialsWithInvalidFile points the application default
// credentials environment variable at a missing file and expects NewClient to fail.
func TestNewClientAppDefaultCredentialsWithInvalidFile(t *testing.T) {
	// Remember whether the variable was set at all, so that an originally unset variable is
	// unset again afterwards rather than left behind as an empty string.
	current, wasSet := os.LookupEnv(credEnvVar)

	if err := os.Setenv(credEnvVar, "../testdata/non_existing.json"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if wasSet {
			os.Setenv(credEnvVar, current)
		} else {
			os.Unsetenv(credEnvVar)
		}
	}()

	conf := &internal.AuthConfig{}
	if c, err := NewClient(context.Background(), conf); c != nil || err == nil {
		t.Errorf("Auth() = (%v, %v); want (nil, error)", c, err)
	}
}

// TestNewClientInvalidCredentialFile checks that unusable credential file paths make
// NewClient fail.
func TestNewClientInvalidCredentialFile(t *testing.T) {
	// NOTE(review): unlike the "../testdata/..." paths used elsewhere in this file, these are
	// relative to the package directory; confirm whether the intent is "file is missing" or
	// "file exists but is not valid credentials".
	invalidFiles := []string{
		"testdata",
		"testdata/plain_text.txt",
	}

	ctx := context.Background()
	for _, tc := range invalidFiles {
		conf := &internal.AuthConfig{
			Opts: []option.ClientOption{
				option.WithCredentialsFile(tc),
			},
		}
		if c, err := NewClient(ctx, conf); c != nil || err == nil {
			t.Errorf("Auth() = (%v, %v); want (nil, error)", c, err)
		}
	}
}

// TestNewClientExplicitNoAuth checks that NewClient succeeds when authentication is
// explicitly disabled.
func TestNewClientExplicitNoAuth(t *testing.T) {
	ctx := context.Background()
	conf := &internal.AuthConfig{
		Opts: []option.ClientOption{
			option.WithoutAuthentication(),
		},
	}
	if c, err := NewClient(ctx, conf); c == nil || err != nil {
		t.Errorf("Auth() = (%v, %v); want (auth, nil)", c, err)
	}
}

// TestNewClientEmulatorHostEnvVar checks that setting the emulator host environment variable
// redirects the client's endpoints to that host and installs an emulatedSigner.
func TestNewClientEmulatorHostEnvVar(t *testing.T) {
	emulatorHost := "localhost:9099"
	idToolkitV1Endpoint := "http://localhost:9099/identitytoolkit.googleapis.com/v1"
	idToolkitV2Endpoint := "http://localhost:9099/identitytoolkit.googleapis.com/v2"

	// Restore the previous state of the variable instead of unconditionally unsetting it, and
	// fail fast if it cannot be set at all.
	previous, wasSet := os.LookupEnv(emulatorHostEnvVar)
	if err := os.Setenv(emulatorHostEnvVar, emulatorHost); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if wasSet {
			os.Setenv(emulatorHostEnvVar, previous)
		} else {
			os.Unsetenv(emulatorHostEnvVar)
		}
	}()

	client, err := NewClient(context.Background(), &internal.AuthConfig{})
	if err != nil {
		t.Fatal(err)
	}
	baseClient := client.baseClient
	if baseClient.userManagementEndpoint != idToolkitV1Endpoint {
		t.Errorf("baseClient.userManagementEndpoint = %q; want = %q", baseClient.userManagementEndpoint, idToolkitV1Endpoint)
	}
	if baseClient.providerConfigEndpoint != idToolkitV2Endpoint {
		t.Errorf("baseClient.providerConfigEndpoint = %q; want = %q", baseClient.providerConfigEndpoint, idToolkitV2Endpoint)
	}
	if baseClient.tenantMgtEndpoint != idToolkitV2Endpoint {
		t.Errorf("baseClient.tenantMgtEndpoint = %q; want = %q", baseClient.tenantMgtEndpoint, idToolkitV2Endpoint)
	}
	if _, ok := baseClient.signer.(emulatedSigner); !ok {
		t.Errorf("baseClient.signer = %#v; want = %#v", baseClient.signer, emulatedSigner{})
	}
}

// TestCustomToken mints a custom token without claims and validates it with
// verifyCustomToken.
func TestCustomToken(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			signer: testSigner,
			clock:  testClock,
		},
	}
	token, err := client.CustomToken(context.Background(), "user1")
	if err != nil {
		t.Fatal(err)
	}
	if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil {
		t.Fatal(err)
	}
}

// TestCustomTokenWithClaims mints a custom token with developer claims and validates that
// the claims round-trip.
func TestCustomTokenWithClaims(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			signer: testSigner,
			clock:  testClock,
		},
	}
	claims := map[string]interface{}{
		"foo":     "bar",
		"premium": true,
		"count":   float64(123),
	}
	token, err := client.CustomTokenWithClaims(context.Background(), "user1", claims)
	if err != nil {
		t.Fatal(err)
	}
	if err := verifyCustomToken(context.Background(), token, claims, ""); err != nil {
		t.Fatal(err)
	}
}

// TestCustomTokenWithNilClaims checks that a nil claims map is accepted.
func TestCustomTokenWithNilClaims(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			signer: testSigner,
			clock:  testClock,
		},
	}
	token, err := client.CustomTokenWithClaims(context.Background(), "user1", nil)
if err != nil {
		t.Fatal(err)
	}
	if err := verifyCustomToken(context.Background(), token, nil, ""); err != nil {
		t.Fatal(err)
	}
}

// TestCustomTokenForTenant checks that a client with a tenant ID mints custom tokens that
// verifyCustomToken accepts for that tenant.
func TestCustomTokenForTenant(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			tenantID: "tenantID",
			signer:   testSigner,
			clock:    testClock,
		},
	}
	claims := map[string]interface{}{
		"foo":     "bar",
		"premium": true,
	}
	token, err := client.CustomTokenWithClaims(context.Background(), "user1", claims)
	if err != nil {
		t.Fatal(err)
	}
	if err := verifyCustomToken(context.Background(), token, claims, "tenantID"); err != nil {
		t.Fatal(err)
	}
}

// TestCustomTokenError checks that an empty UID, an over-long UID and reserved claim names
// are all rejected with an empty token and a non-nil error.
func TestCustomTokenError(t *testing.T) {
	cases := []struct {
		name   string
		uid    string
		claims map[string]interface{}
	}{
		{"EmptyName", "", nil},
		{"LongUid", strings.Repeat("a", 129), nil},
		{"ReservedClaim", "uid", map[string]interface{}{"sub": "1234"}},
		{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234", "aud": "foo"}},
	}

	client := &baseClient{
		signer: testSigner,
		clock:  testClock,
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			token, err := client.CustomTokenWithClaims(context.Background(), tc.uid, tc.claims)
			if token != "" || err == nil {
				t.Errorf("CustomTokenWithClaims(%q) = (%q, %v); want = (\"\", error)", tc.name, token, err)
			}
		})
	}
}

// TestCustomTokenInvalidCredential checks that both token-minting entry points fail when the
// IAM signer is backed by a mock token source. Retries are disabled on the signer's HTTP
// client before the calls are made.
func TestCustomTokenInvalidCredential(t *testing.T) {
	ctx := context.Background()
	conf := &internal.AuthConfig{
		Opts: optsWithTokenSource,
	}
	s, err := NewClient(ctx, conf)
	if err != nil {
		t.Fatal(err)
	}

	s.signer.(*iamSigner).httpClient.RetryConfig = nil
	token, err := s.CustomToken(ctx, "user1")
	if token != "" || err == nil {
		// Name the function that was actually called so a failure is attributed correctly.
		t.Errorf("CustomToken() = (%q, %v); want = (\"\", error)", token, err)
	}

	token, err = s.CustomTokenWithClaims(ctx, "user1", map[string]interface{}{"foo": "bar"})
	if token != "" || err == nil {
		t.Errorf("CustomTokenWithClaims() = (%q, %v); want = (\"\", error)", token, err)
	}
}

// TestVerifyIDToken verifies the default test ID token and checks its decoded fields.
func TestVerifyIDToken(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}

	ft, err :=
client.VerifyIDToken(context.Background(), testIDToken)
	if err != nil {
		t.Fatal(err)
	}

	now := testClock.Now().Unix()
	if ft.AuthTime != now-100 {
		t.Errorf("AuthTime = %d; want = %d", ft.AuthTime, now-100)
	}
	if ft.Firebase.SignInProvider != "custom" {
		t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom")
	}
	if ft.Firebase.Tenant != "" {
		t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "")
	}
	if ft.Claims["admin"] != true {
		t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
	}
	if ft.UID != ft.Subject {
		t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
	}
}

// TestVerifyIDTokenFromTenant verifies an ID token carrying a tenant in its "firebase" claim
// with a client that has no tenant ID, and checks the tenant is surfaced on the result.
func TestVerifyIDTokenFromTenant(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}

	idToken := getIDToken(mockIDTokenPayload{
		"firebase": map[string]interface{}{
			"tenant":           "tenantID",
			"sign_in_provider": "custom",
		},
	})
	ft, err := client.VerifyIDToken(context.Background(), idToken)
	if err != nil {
		t.Fatal(err)
	}

	now := testClock.Now().Unix()
	if ft.AuthTime != now-100 {
		t.Errorf("AuthTime = %d; want = %d", ft.AuthTime, now-100)
	}
	if ft.Firebase.SignInProvider != "custom" {
		t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom")
	}
	if ft.Firebase.Tenant != "tenantID" {
		t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "tenantID")
	}
	if ft.Claims["admin"] != true {
		t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
	}
	if ft.UID != ft.Subject {
		t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
	}
}

// TestVerifyIDTokenClockSkew checks that tokens whose "iat"/"exp" are off by just under
// clockSkewSeconds are still accepted.
func TestVerifyIDTokenClockSkew(t *testing.T) {
	now := testClock.Now().Unix()
	cases := []struct {
		name  string
		token string
	}{
		{"FutureToken", getIDToken(mockIDTokenPayload{"iat": now + clockSkewSeconds - 1})},
		{"ExpiredToken", getIDToken(mockIDTokenPayload{
			"iat": now - 1000,
			"exp": now - clockSkewSeconds + 1,
		})},
	}

	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ft, err := client.VerifyIDToken(context.Background(), tc.token)
			if err != nil {
				t.Fatalf("VerifyIDToken(%q) = (%q, %v); want = (token, nil)", tc.name, ft, err)
			}
			if ft.Claims["admin"] != true {
				t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
			}
			if ft.UID != ft.Subject {
				t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
			}
		})
	}
}

// TestVerifyIDTokenInvalidSignature replaces the signature segment of a valid ID token and
// expects verification to fail with an invalid-ID-token error.
func TestVerifyIDTokenInvalidSignature(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}

	parts := strings.Split(testIDToken, ".")
	// JWT segments are joined with '.'. Joining with any other separator yields a single
	// segment, so the token would be rejected for its segment count and the signature check
	// would never be reached.
	token := fmt.Sprintf("%s.%s.invalidsignature", parts[0], parts[1])
	ft, err := client.VerifyIDToken(context.Background(), token)
	if ft != nil || !IsIDTokenInvalid(err) {
		t.Errorf("VerifyIDToken('invalid-signature') = (%v, %v); want = (nil, IDTokenInvalid)", ft, err)
	}
}

// TestVerifyIDTokenError feeds malformed or invalid ID tokens to VerifyIDToken and checks
// the prefix of each resulting error message.
func TestVerifyIDTokenError(t *testing.T) {
	now := testClock.Now().Unix()
	cases := []struct {
		name, token, want string
	}{
		{
			name:  "NoKid",
			token: getIDTokenWithKid("", nil),
			want:  "ID token has no 'kid' header",
		},
		{
			name:  "WrongKid",
			token: getIDTokenWithKid("foo", nil),
			want:  "failed to verify token signature",
		},
		{
			name:  "BadAudience",
			token: getIDToken(mockIDTokenPayload{"aud": "bad-audience"}),
			want: `ID token has invalid 'aud' (audience) claim; expected "mock-project-id" but ` +
				`got "bad-audience"; make sure the ID token comes from the same Firebase ` +
				`project as the credential used to authenticate this SDK; see ` +
				`https://firebase.google.com/docs/auth/admin/verify-id-tokens for details on how ` +
				`to retrieve a valid ID token`,
		},
		{
			name:  "BadIssuer",
			token: getIDToken(mockIDTokenPayload{"iss": "bad-issuer"}),
			want: `ID token has invalid 'iss' (issuer) claim; expected ` +
				`"https://securetoken.google.com/mock-project-id" but got "bad-issuer"; make sure the ` +
				`ID token comes from the same Firebase project as the credential used to authenticate ` +
				`this SDK; see https://firebase.google.com/docs/auth/admin/verify-id-tokens for ` +
`details on how to retrieve a valid ID token`,
		},
		{
			name:  "EmptySubject",
			token: getIDToken(mockIDTokenPayload{"sub": ""}),
			want:  "ID token has empty 'sub' (subject) claim",
		},
		{
			name:  "NonStringSubject",
			token: getIDToken(mockIDTokenPayload{"sub": 10}),
			want:  "json: cannot unmarshal number into Go struct field Token.sub of type string",
		},
		{
			name:  "TooLongSubject",
			token: getIDToken(mockIDTokenPayload{"sub": strings.Repeat("a", 129)}),
			want:  "ID token has a 'sub' (subject) claim longer than 128 characters",
		},
		{
			name:  "FutureToken",
			token: getIDToken(mockIDTokenPayload{"iat": now + clockSkewSeconds + 1}),
			want:  "ID token issued at future timestamp",
		},
		{
			name: "ExpiredToken",
			token: getIDToken(mockIDTokenPayload{
				"iat": now - 1000,
				"exp": now - clockSkewSeconds - 1,
			}),
			want: "ID token has expired",
		},
		{
			name:  "EmptyToken",
			token: "",
			want:  "ID token must be a non-empty string",
		},
		{
			name:  "TooFewSegments",
			token: "foo",
			want:  "incorrect number of segments",
		},
		{
			name:  "TooManySegments",
			token: "fo.ob.ar.baz",
			want:  "incorrect number of segments",
		},
		{
			name:  "MalformedToken",
			token: "foo.bar.baz",
			want:  "invalid character",
		},
	}

	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := client.VerifyIDToken(context.Background(), tc.token)
			// The error message is matched by prefix only. A nil err makes the first operand
			// true, so err.Error() is never evaluated on a nil error.
			if !IsIDTokenInvalid(err) || !strings.HasPrefix(err.Error(), tc.want) {
				t.Errorf("VerifyIDToken(%q) = %v; want = %q", tc.name, err, tc.want)
			}
			// The expired case must additionally be reported as expired.
			if tc.name == "ExpiredToken" && !IsIDTokenExpired(err) {
				t.Errorf("VerifyIDToken(%q) = %v; want = IDTokenExpired", tc.name, err)
			}
		})
	}
}

// TestVerifyIDTokenInvalidAlgorithm re-signs the payload of the default test token under a
// header declaring the HS256 algorithm and expects verification to reject it.
func TestVerifyIDTokenInvalidAlgorithm(t *testing.T) {
	var payload mockIDTokenPayload
	segments := strings.Split(testIDToken, ".")
	if err := decode(segments[1], &payload); err != nil {
		t.Fatal(err)
	}

	info := &jwtInfo{
		header: jwtHeader{
			Algorithm: "HS256",
			Type:      "JWT",
			KeyID:     "mock-key-id-1",
		},
		payload: payload,
	}
	token, err :=
info.Token(context.Background(), testSigner)
	if err != nil {
		t.Fatal(err)
	}

	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}
	_, err = client.VerifyIDToken(context.Background(), token)
	if !IsIDTokenInvalid(err) {
		// Report the error actually returned; it may be non-nil with a different code.
		t.Errorf("VerifyIDToken(InvalidAlgorithm) = %v; want = IDTokenInvalid", err)
	}
}

// TestVerifyIDTokenWithNoProjectID checks that a client created without a project ID refuses
// to verify ID tokens, even when given a working key source.
func TestVerifyIDTokenWithNoProjectID(t *testing.T) {
	conf := &internal.AuthConfig{
		ProjectID: "",
		Opts:      optsWithTokenSource,
	}
	c, err := NewClient(context.Background(), conf)
	if err != nil {
		t.Fatal(err)
	}
	c.idTokenVerifier.keySource = testIDTokenVerifier.keySource
	if _, err := c.VerifyIDToken(context.Background(), testIDToken); err == nil {
		t.Error("VerifyIDToken() = nil; want error")
	}
}

// TestVerifyIDTokenUnsigned checks that an emulator-style token is rejected by a client that
// is not in emulator mode.
func TestVerifyIDTokenUnsigned(t *testing.T) {
	token := getEmulatedIDToken(nil)

	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
		},
	}
	_, err := client.VerifyIDToken(context.Background(), token)
	if !IsIDTokenInvalid(err) {
		t.Errorf("VerifyIDToken(Unsigned) = %v; want = IDTokenInvalid", err)
	}
}

// TestEmulatorVerifyIDToken checks that a client in emulator mode accepts an emulator-style
// token and decodes its fields.
func TestEmulatorVerifyIDToken(t *testing.T) {
	s := echoServer(testGetUserResponse, t)
	defer s.Close()
	s.Client.idTokenVerifier = testIDTokenVerifier
	s.Client.isEmulator = true

	token := getEmulatedIDToken(nil)
	ft, err := s.Client.VerifyIDToken(context.Background(), token)
	if err != nil {
		t.Fatal(err)
	}

	now := testClock.Now().Unix()
	if ft.AuthTime != now-100 {
		t.Errorf("AuthTime = %d; want = %d", ft.AuthTime, now-100)
	}
	if ft.Firebase.SignInProvider != "custom" {
		t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom")
	}
	if ft.Firebase.Tenant != "" {
		t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "")
	}
	if ft.Claims["admin"] != true {
		t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
	}
	if ft.UID != ft.Subject {
		t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
	}
}

// TestEmulatorVerifyIDTokenExpiredError checks that expiry is still enforced in emulator
// mode.
func TestEmulatorVerifyIDTokenExpiredError(t *testing.T) {
	s := echoServer(testGetUserResponse, t)
	defer
s.Close()
	s.Client.idTokenVerifier = testIDTokenVerifier
	s.Client.isEmulator = true

	now := testClock.Now().Unix()
	token := getEmulatedIDToken(mockIDTokenPayload{
		"iat": now - 1000,
		"exp": now - clockSkewSeconds - 1,
	})

	_, err := s.Client.VerifyIDToken(context.Background(), token)
	if !IsIDTokenExpired(err) {
		t.Errorf("VerifyIDToken(Expired) = %v; want = IDTokenExpired", err)
	}
}

// TestEmulatorVerifyIDTokenUnreachableEmulator checks that, in emulator mode, a refused
// connection surfaces as an Unavailable error whose message starts with
// "failed to establish a connection".
func TestEmulatorVerifyIDTokenUnreachableEmulator(t *testing.T) {
	conf := &internal.AuthConfig{
		Opts:      optsWithTokenSource,
		ProjectID: testProjectID,
		Version:   testVersion,
	}
	client, err := NewClient(context.Background(), conf)
	if err != nil {
		t.Fatal(err)
	}
	client.httpClient.Client.Transport = eConnRefusedTransport{}
	client.httpClient.RetryConfig = nil
	client.isEmulator = true

	token := getEmulatedIDToken(nil)
	_, err = client.VerifyIDToken(context.Background(), token)
	if err == nil || !errorutils.IsUnavailable(err) || !strings.HasPrefix(err.Error(), "failed to establish a connection") {
		t.Errorf("VerifyIDToken(UnreachableEmulator) = %v; want = Unavailable", err)
	}
}

// TestCustomTokenVerification checks that a custom token is not accepted as an ID token.
func TestCustomTokenVerification(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: testIDTokenVerifier,
			signer:          testSigner,
			clock:           testClock,
		},
	}
	token, err := client.CustomToken(context.Background(), "user1")
	if err != nil {
		t.Fatal(err)
	}

	if _, err := client.VerifyIDToken(context.Background(), token); !IsIDTokenInvalid(err) {
		// Print the error actually returned; it may be non-nil with a different code.
		t.Errorf("VerifyIDToken() = %v; want = IDTokenInvalid", err)
	}
}

// TestCertificateRequestError checks that a failing key source is reported as a
// certificate-fetch failure.
func TestCertificateRequestError(t *testing.T) {
	tv, err := newIDTokenVerifier(context.Background(), testProjectID)
	if err != nil {
		t.Fatal(err)
	}
	tv.keySource = &mockKeySource{nil, errors.New("mock error")}
	client := &Client{
		baseClient: &baseClient{
			idTokenVerifier: tv,
		},
	}
	if _, err := client.VerifyIDToken(context.Background(), testIDToken); !IsCertificateFetchFailed(err) {
		t.Errorf("VerifyIDToken() = %v; want = CertificateFetchFailed", err)
	}
}

// TestVerifyIDTokenAndCheckRevoked checks the happy path of the revocation-checking
// verifier against a canned user response.
func TestVerifyIDTokenAndCheckRevoked(t *testing.T) {
	s :=
echoServer(testGetUserResponse, t)
	defer s.Close()

	s.Client.idTokenVerifier = testIDTokenVerifier
	ft, err := s.Client.VerifyIDTokenAndCheckRevoked(context.Background(), testIDToken)
	if err != nil {
		t.Fatal(err)
	}
	if ft.Claims["admin"] != true {
		t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
	}
	if ft.UID != ft.Subject {
		t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
	}
}

// TestVerifyIDTokenDoesNotCheckRevoked checks that plain VerifyIDToken accepts a token with
// a very old "iat" (1970), i.e. that it performs no revocation check.
func TestVerifyIDTokenDoesNotCheckRevoked(t *testing.T) {
	s := echoServer(testGetUserResponse, t)
	defer s.Close()
	revokedToken := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970})

	s.Client.idTokenVerifier = testIDTokenVerifier
	ft, err := s.Client.VerifyIDToken(context.Background(), revokedToken)
	if err != nil {
		t.Fatal(err)
	}
	if ft.Claims["admin"] != true {
		t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
	}
	if ft.UID != ft.Subject {
		t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
	}
}

// TestInvalidTokenDoesNotCheckRevokedOrDisabled checks that an invalid (empty) token is
// rejected before any request is recorded by the echo server.
func TestInvalidTokenDoesNotCheckRevokedOrDisabled(t *testing.T) {
	s := echoServer(testGetUserResponse, t)
	defer s.Close()
	s.Client.idTokenVerifier = testIDTokenVerifier

	ft, err := s.Client.VerifyIDTokenAndCheckRevoked(context.Background(), "")
	if ft != nil || !IsIDTokenInvalid(err) || IsIDTokenRevoked(err) || IsUserDisabled(err) {
		t.Errorf("VerifyIDTokenAndCheckRevoked() = (%v, %v); want = (nil, IDTokenInvalid)", ft, err)
	}
	// s.Req holds the requests the echo server received; none are expected here.
	if len(s.Req) != 0 {
		t.Errorf("Revocation checks = %d; want = 0", len(s.Req))
	}
}

// TestVerifyIDTokenAndCheckRevokedError checks that a token with a very old "iat" is
// reported as revoked, with the exact message "ID token has been revoked".
func TestVerifyIDTokenAndCheckRevokedError(t *testing.T) {
	s := echoServer(testGetUserResponse, t)
	defer s.Close()
	revokedToken := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970})
	s.Client.idTokenVerifier = testIDTokenVerifier

	p, err := s.Client.VerifyIDTokenAndCheckRevoked(context.Background(), revokedToken)
	we := "ID token has been revoked"
	if p != nil || !IsIDTokenRevoked(err) || !IsIDTokenInvalid(err) || err.Error() != we {
		t.Errorf("VerifyIDTokenAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, %v)", p, err, nil, we)
	}
}
func TestVerifyIDTokenAndCheckDisabledError(t *testing.T) { s := echoServer(testGetDisabledUserResponse, t) defer s.Close() revokedToken := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970}) s.Client.idTokenVerifier = testIDTokenVerifier p, err := s.Client.VerifyIDTokenAndCheckRevoked(context.Background(), revokedToken) we := "user has been disabled" if p != nil || !IsUserDisabled(err) || !IsIDTokenInvalid(err) || err.Error() != we { t.Errorf("VerifyIDTokenAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, %v)", p, err, nil, we) } } func TestIDTokenRevocationCheckUserMgtError(t *testing.T) { resp := `{ "kind" : "identitytoolkit#GetAccountInfoResponse", "users" : [] }` s := echoServer([]byte(resp), t) defer s.Close() revokedToken := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970}) s.Client.idTokenVerifier = testIDTokenVerifier p, err := s.Client.VerifyIDTokenAndCheckRevoked(context.Background(), revokedToken) if p != nil || !IsUserNotFound(err) { t.Errorf("VerifyIDTokenAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, UserNotFound)", p, err, nil) } } func TestVerifySessionCookie(t *testing.T) { client := &Client{ baseClient: &baseClient{ cookieVerifier: testCookieVerifier, }, } ft, err := client.VerifySessionCookie(context.Background(), testSessionCookie) if err != nil { t.Fatal(err) } now := testClock.Now().Unix() if ft.AuthTime != now-100 { t.Errorf("AuthTime = %d; want = %d", ft.AuthTime, now-100) } if ft.Firebase.SignInProvider != "custom" { t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom") } if ft.Firebase.Tenant != "" { t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "") } if ft.Claims["admin"] != true { t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) } if ft.UID != ft.Subject { t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) } } func TestVerifySessionCookieFromTenant(t *testing.T) { client := &Client{ baseClient: &baseClient{ cookieVerifier: testCookieVerifier, }, } 
cookie := getSessionCookie(mockIDTokenPayload{ "firebase": map[string]interface{}{ "tenant": "tenantID", "sign_in_provider": "custom", }, }) ft, err := client.VerifySessionCookie(context.Background(), cookie) if err != nil { t.Fatal(err) } now := testClock.Now().Unix() if ft.AuthTime != now-100 { t.Errorf("AuthTime = %d; want = %d", ft.AuthTime, now-100) } if ft.Firebase.SignInProvider != "custom" { t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom") } if ft.Firebase.Tenant != "tenantID" { t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "tenantID") } if ft.Claims["admin"] != true { t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) } if ft.UID != ft.Subject { t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) } } func TestVerifySessionCookieError(t *testing.T) { now := testClock.Now().Unix() cases := []struct { name, token, want string }{ { name: "BadAudience", token: getSessionCookie(mockIDTokenPayload{"aud": "bad-audience"}), want: `session cookie has invalid 'aud' (audience) claim; expected "mock-project-id" but ` + `got "bad-audience"; make sure the session cookie comes from the same Firebase ` + `project as the credential used to authenticate this SDK; see ` + `https://firebase.google.com/docs/auth/admin/manage-cookies for details on how ` + `to retrieve a valid session cookie`, }, { name: "BadIssuer", token: getSessionCookie(mockIDTokenPayload{"iss": "bad-issuer"}), want: `session cookie has invalid 'iss' (issuer) claim; expected ` + `"https://session.firebase.google.com/mock-project-id" but got "bad-issuer"; make sure the ` + `session cookie comes from the same Firebase project as the credential used to authenticate ` + `this SDK; see https://firebase.google.com/docs/auth/admin/manage-cookies for ` + `details on how to retrieve a valid session cookie`, }, { name: "EmptySubject", token: getSessionCookie(mockIDTokenPayload{"sub": ""}), want: "session cookie has empty 'sub' (subject) claim", }, 
{ name: "NonStringSubject", token: getSessionCookie(mockIDTokenPayload{"sub": 10}), want: "json: cannot unmarshal number into Go struct field Token.sub of type string", }, { name: "TooLongSubject", token: getSessionCookie(mockIDTokenPayload{"sub": strings.Repeat("a", 129)}), want: "session cookie has a 'sub' (subject) claim longer than 128 characters", }, { name: "FutureToken", token: getSessionCookie(mockIDTokenPayload{"iat": now + clockSkewSeconds + 1}), want: "session cookie issued at future timestamp", }, { name: "ExpiredToken", token: getSessionCookie(mockIDTokenPayload{ "iat": now - 1000, "exp": now - clockSkewSeconds - 1, }), want: "session cookie has expired", }, { name: "EmptyToken", token: "", want: "session cookie must be a non-empty string", }, { name: "TooFewSegments", token: "foo", want: "incorrect number of segments", }, { name: "TooManySegments", token: "fo.ob.ar.baz", want: "incorrect number of segments", }, { name: "MalformedToken", token: "foo.bar.baz", want: "invalid character", }, } client := &Client{ baseClient: &baseClient{ cookieVerifier: testCookieVerifier, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { _, err := client.VerifySessionCookie(context.Background(), tc.token) if !IsSessionCookieInvalid(err) || !strings.HasPrefix(err.Error(), tc.want) { t.Errorf("VerifySessionCookie(%q) = %v; want = %q", tc.name, err, tc.want) } if tc.name == "ExpiredToken" && !IsSessionCookieExpired(err) { t.Errorf("VerifySessionCookie(%q) = %v; want = SessionCookieExpired", tc.name, err) } }) } } func TestVerifySessionCookieDoesNotCheckRevoked(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() revokedCookie := getSessionCookie(mockIDTokenPayload{"uid": "uid", "iat": 1970}) s.Client.cookieVerifier = testCookieVerifier ft, err := s.Client.VerifySessionCookie(context.Background(), revokedCookie) if err != nil { t.Fatal(err) } if ft.Claims["admin"] != true { t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) 
} if ft.UID != ft.Subject { t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) } } func TestVerifySessionCookieAndCheckRevoked(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.cookieVerifier = testCookieVerifier ft, err := s.Client.VerifySessionCookieAndCheckRevoked(context.Background(), testSessionCookie) if err != nil { t.Fatal(err) } if ft.Claims["admin"] != true { t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) } if ft.UID != ft.Subject { t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) } } func TestInvalidCookieDoesNotCheckRevoked(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.cookieVerifier = testCookieVerifier ft, err := s.Client.VerifySessionCookieAndCheckRevoked(context.Background(), "") if ft != nil || !IsSessionCookieInvalid(err) { t.Errorf("VerifySessionCookieAndCheckRevoked() = (%v, %v); want = (nil, SessionCookieInvalid)", ft, err) } if len(s.Req) != 0 { t.Errorf("Revocation checks = %d; want = 0", len(s.Req)) } } func TestVerifySessionCookieAndCheckRevokedError(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() revokedCookie := getSessionCookie(mockIDTokenPayload{"uid": "uid", "iat": 1970}) s.Client.cookieVerifier = testCookieVerifier p, err := s.Client.VerifySessionCookieAndCheckRevoked(context.Background(), revokedCookie) we := "session cookie has been revoked" if p != nil || !IsSessionCookieRevoked(err) || !IsSessionCookieInvalid(err) || err.Error() != we { t.Errorf("VerifySessionCookieAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, %v)", p, err, nil, we) } } func TestVerifySessionCookieAndCheckDisabledError(t *testing.T) { s := echoServer(testGetDisabledUserResponse, t) defer s.Close() revokedCookie := getSessionCookie(mockIDTokenPayload{"uid": "uid", "iat": 1970}) s.Client.cookieVerifier = testCookieVerifier p, err := s.Client.VerifySessionCookieAndCheckRevoked(context.Background(), revokedCookie) we := 
"user has been disabled" if p != nil || !IsUserDisabled(err) || !IsSessionCookieInvalid(err) || err.Error() != we { t.Errorf("VerifySessionCookieAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, %v)", p, err, nil, we) } } func TestCookieRevocationCheckUserMgtError(t *testing.T) { resp := `{ "kind" : "identitytoolkit#GetAccountInfoResponse", "users" : [] }` s := echoServer([]byte(resp), t) defer s.Close() revokedCookie := getSessionCookie(mockIDTokenPayload{"uid": "uid", "iat": 1970}) s.Client.cookieVerifier = testCookieVerifier p, err := s.Client.VerifySessionCookieAndCheckRevoked(context.Background(), revokedCookie) if p != nil || !IsUserNotFound(err) { t.Errorf("VerifySessionCookieAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, UserNotFound)", p, err, nil) } } func TestVerifySessionCookieUnsigned(t *testing.T) { token := getEmulatedSessionCookie(nil) client := &Client{ baseClient: &baseClient{ cookieVerifier: testCookieVerifier, }, } _, err := client.VerifySessionCookie(context.Background(), token) if !IsSessionCookieInvalid(err) { t.Errorf("VerifySessionCookie(Unsigned) = %v; want = IDTokenInvalid", err) } } func TestEmulatorVerifySessionCookie(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.cookieVerifier = testCookieVerifier s.Client.isEmulator = true token := getEmulatedSessionCookie(nil) ft, err := s.Client.VerifySessionCookie(context.Background(), token) if err != nil { t.Fatal(err) } now := testClock.Now().Unix() if ft.AuthTime != now-100 { t.Errorf("AuthTime = %d; want = %d", ft.AuthTime, now-100) } if ft.Firebase.SignInProvider != "custom" { t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom") } if ft.Firebase.Tenant != "" { t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "") } if ft.Claims["admin"] != true { t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"]) } if ft.UID != ft.Subject { t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject) } } func 
TestEmulatorVerifySessionCookieExpiredError(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.cookieVerifier = testCookieVerifier s.Client.isEmulator = true now := testClock.Now().Unix() token := getEmulatedSessionCookie(mockIDTokenPayload{ "iat": now - 1000, "exp": now - clockSkewSeconds - 1, }) _, err := s.Client.VerifySessionCookie(context.Background(), token) if !IsSessionCookieExpired(err) { t.Errorf("VerifySessionCookie(Expired) = %v; want = IDTokenExpired", err) } } func TestEmulatorVerifySessionCookieUnreachableEmulator(t *testing.T) { conf := &internal.AuthConfig{ Opts: optsWithTokenSource, ProjectID: testProjectID, Version: testVersion, } client, err := NewClient(context.Background(), conf) if err != nil { t.Fatal(err) } client.httpClient.Client.Transport = eConnRefusedTransport{} client.httpClient.RetryConfig = nil client.isEmulator = true token := getEmulatedSessionCookie(nil) _, err = client.VerifySessionCookie(context.Background(), token) if err == nil || !errorutils.IsUnavailable(err) || !strings.HasPrefix(err.Error(), "failed to establish a connection") { t.Errorf("VerifyIDToken(UnreachableEmulator) = %v; want = Unavailable", err) } } func signerForTests(ctx context.Context) (cryptoSigner, error) { creds, err := transport.Creds(ctx, optsWithServiceAcct...) 
if err != nil { return nil, err } return signerFromCreds(creds.JSON) } func idTokenVerifierForTests(ctx context.Context) (*tokenVerifier, error) { tv, err := newIDTokenVerifier(ctx, testProjectID) if err != nil { return nil, err } ks, err := newMockKeySource("../testdata/public_certs.json") if err != nil { return nil, err } tv.keySource = ks tv.clock = testClock return tv, nil } func cookieVerifierForTests(ctx context.Context) (*tokenVerifier, error) { tv, err := newSessionCookieVerifier(ctx, testProjectID) if err != nil { return nil, err } ks, err := newMockKeySource("../testdata/public_certs.json") if err != nil { return nil, err } tv.keySource = ks tv.clock = testClock return tv, nil } // mockKeySource provides access to a set of in-memory public keys. type mockKeySource struct { keys []*publicKey err error } func newMockKeySource(filePath string) (*mockKeySource, error) { certs, err := ioutil.ReadFile(filePath) if err != nil { return nil, err } keys, err := parsePublicKeys(certs) if err != nil { return nil, err } return &mockKeySource{ keys: keys, }, nil } func (k *mockKeySource) Keys(ctx context.Context) ([]*publicKey, error) { return k.keys, k.err } type mockIDTokenPayload map[string]interface{} func (p mockIDTokenPayload) decodeFrom(s string) error { return decode(s, &p) } type eConnRefusedTransport struct{} func (eConnRefusedTransport) RoundTrip(*http.Request) (*http.Response, error) { return nil, syscall.ECONNREFUSED } func getSessionCookie(p mockIDTokenPayload) string { return getSessionCookieWithSigner(testSigner, p) } func getEmulatedSessionCookie(p mockIDTokenPayload) string { return getSessionCookieWithSigner(emulatedSigner{}, p) } func getSessionCookieWithSigner(signer cryptoSigner, p mockIDTokenPayload) string { pCopy := map[string]interface{}{ "iss": "https://session.firebase.google.com/" + testProjectID, } for k, v := range p { pCopy[k] = v } return getIDTokenWithSigner(signer, pCopy) } func getIDTokenWithSigner(signer cryptoSigner, p 
mockIDTokenPayload) string { return getIDTokenWithSignerAndKid(signer, "mock-key-id-1", p) } func getIDToken(p mockIDTokenPayload) string { return getIDTokenWithSigner(testSigner, p) } func getIDTokenWithKid(kid string, p mockIDTokenPayload) string { return getIDTokenWithSignerAndKid(testSigner, kid, p) } func getEmulatedIDToken(p mockIDTokenPayload) string { return getIDTokenWithSignerAndKid(emulatedSigner{}, "mock-key-id-1", p) } func getIDTokenWithSignerAndKid(signer cryptoSigner, kid string, p mockIDTokenPayload) string { pCopy := mockIDTokenPayload{ "aud": testProjectID, "iss": "https://securetoken.google.com/" + testProjectID, "iat": testClock.Now().Unix() - 100, "exp": testClock.Now().Unix() + 3600, "auth_time": testClock.Now().Unix() - 100, "sub": "1234567890", "firebase": map[string]interface{}{ "identities": map[string]interface{}{}, "sign_in_provider": "custom", }, "admin": true, } for k, v := range p { pCopy[k] = v } info := &jwtInfo{ header: jwtHeader{ Algorithm: signer.Algorithm(), Type: "JWT", KeyID: kid, }, payload: pCopy, } token, err := info.Token(context.Background(), signer) logFatal(err) return token } func checkIDTokenVerifier(tv *tokenVerifier, projectID string) error { if tv == nil { return errors.New("tokenVerifier not initialized") } if tv.projectID != projectID { return fmt.Errorf("projectID = %q; want = %q", tv.projectID, projectID) } if tv.shortName != "ID token" { return fmt.Errorf("shortName = %q; want = %q", tv.shortName, "ID token") } if tv.invalidTokenCode != idTokenInvalid { return fmt.Errorf("invalidTokenCode = %q; want = %q", tv.invalidTokenCode, idTokenInvalid) } if tv.expiredTokenCode != idTokenExpired { return fmt.Errorf("expiredTokenCode = %q; want = %q", tv.expiredTokenCode, idTokenExpired) } return nil } func checkCookieVerifier(tv *tokenVerifier, projectID string) error { if tv == nil { return errors.New("tokenVerifier not initialized") } if tv.projectID != projectID { return fmt.Errorf("projectID = %q; want = %q", 
tv.projectID, projectID) } if tv.shortName != "session cookie" { return fmt.Errorf("shortName = %q; want = %q", tv.shortName, "session cookie") } if tv.invalidTokenCode != sessionCookieInvalid { return fmt.Errorf("invalidTokenCode = %q; want = %q", tv.invalidTokenCode, sessionCookieInvalid) } if tv.expiredTokenCode != sessionCookieExpired { return fmt.Errorf("expiredTokenCode = %q; want = %q", tv.expiredTokenCode, sessionCookieExpired) } return nil } func checkBaseClient(client *Client, wantProjectID string) error { baseClient := client.baseClient if baseClient.userManagementEndpoint != defaultIDToolkitV1Endpoint { return fmt.Errorf("userManagementEndpoint = %q; want = %q", baseClient.userManagementEndpoint, defaultIDToolkitV1Endpoint) } if baseClient.providerConfigEndpoint != defaultIDToolkitV2Endpoint { return fmt.Errorf("providerConfigEndpoint = %q; want = %q", baseClient.providerConfigEndpoint, defaultIDToolkitV2Endpoint) } if baseClient.tenantMgtEndpoint != defaultIDToolkitV2Endpoint { return fmt.Errorf("providerConfigEndpoint = %q; want = %q", baseClient.providerConfigEndpoint, defaultIDToolkitV2Endpoint) } if baseClient.projectID != wantProjectID { return fmt.Errorf("projectID = %q; want = %q", baseClient.projectID, wantProjectID) } req, err := http.NewRequest(http.MethodGet, "https://firebase.google.com", nil) if err != nil { return err } for _, opt := range baseClient.httpClient.Opts { opt(req) } version := req.Header.Get("X-Client-Version") wantVersion := fmt.Sprintf("Go/Admin/%s", testVersion) if version != wantVersion { return fmt.Errorf("version = %q; want = %q", version, wantVersion) } xGoogAPIClientHeader := internal.GetMetricsHeader(testVersion) if h := req.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader { return fmt.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader) } return nil } func verifyCustomToken( ctx context.Context, token string, expected map[string]interface{}, tenantID string) error { if err := 
testIDTokenVerifier.verifySignature(ctx, token); err != nil { return err } var ( header jwtHeader payload customToken ) segments := strings.Split(token, ".") if err := decode(segments[0], &header); err != nil { return err } if err := decode(segments[1], &payload); err != nil { return err } email, err := testSigner.Email(ctx) if err != nil { return err } if header.Algorithm != "RS256" { return fmt.Errorf("Algorithm: %q; want: 'RS256'", header.Algorithm) } else if header.Type != "JWT" { return fmt.Errorf("Type: %q; want: 'JWT'", header.Type) } else if payload.Aud != firebaseAudience { return fmt.Errorf("Audience: %q; want: %q", payload.Aud, firebaseAudience) } else if payload.Iss != email { return fmt.Errorf("Issuer: %q; want: %q", payload.Iss, email) } else if payload.Sub != email { return fmt.Errorf("Subject: %q; want: %q", payload.Sub, email) } now := testClock.Now().Unix() if payload.Exp != now+3600 { return fmt.Errorf("Exp: %d; want: %d", payload.Exp, now+3600) } if payload.Iat != now { return fmt.Errorf("Iat: %d; want: %d", payload.Iat, now) } for k, v := range expected { if payload.Claims[k] != v { return fmt.Errorf("Claim[%q]: %v; want: %v", k, payload.Claims[k], v) } } if payload.TenantID != tenantID { return fmt.Errorf("Tenant ID: %q; want: %q", payload.TenantID, tenantID) } return nil } func logFatal(err error) { if err != nil { log.Fatal(err) } } golang-google-firebase-go-4.18.0/auth/email_action_links.go000066400000000000000000000113701505612111400236200ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "encoding/json" "errors" "fmt" "net/url" ) // ActionCodeSettings specifies the required continue/state URL with optional Android and iOS settings. Used when // invoking the email action link generation APIs. type ActionCodeSettings struct { URL string `json:"continueUrl"` HandleCodeInApp bool `json:"canHandleCodeInApp"` IOSBundleID string `json:"iOSBundleId,omitempty"` AndroidPackageName string `json:"androidPackageName,omitempty"` AndroidMinimumVersion string `json:"androidMinimumVersion,omitempty"` AndroidInstallApp bool `json:"androidInstallApp,omitempty"` LinkDomain string `json:"linkDomain,omitempty"` // Deprecated: Use LinkDomain instead. 
DynamicLinkDomain string `json:"dynamicLinkDomain,omitempty"` } func (settings *ActionCodeSettings) toMap() (map[string]interface{}, error) { if settings.URL == "" { return nil, errors.New("URL must not be empty") } url, err := url.Parse(settings.URL) if err != nil || url.Scheme == "" || url.Host == "" { return nil, fmt.Errorf("malformed url string: %q", settings.URL) } if settings.AndroidMinimumVersion != "" || settings.AndroidInstallApp { if settings.AndroidPackageName == "" { return nil, errors.New("Android package name is required when specifying other Android settings") } } b, err := json.Marshal(settings) if err != nil { return nil, err } var result map[string]interface{} if err := json.Unmarshal(b, &result); err != nil { return nil, err } return result, nil } type linkType string const ( emailLinkSignIn linkType = "EMAIL_SIGNIN" emailVerification linkType = "VERIFY_EMAIL" passwordReset linkType = "PASSWORD_RESET" ) // EmailVerificationLink generates the out-of-band email action link for email verification flows for the specified // email address. func (c *baseClient) EmailVerificationLink(ctx context.Context, email string) (string, error) { return c.EmailVerificationLinkWithSettings(ctx, email, nil) } // EmailVerificationLinkWithSettings generates the out-of-band email action link for email verification flows for the // specified email address, using the action code settings provided. func (c *baseClient) EmailVerificationLinkWithSettings( ctx context.Context, email string, settings *ActionCodeSettings) (string, error) { return c.generateEmailActionLink(ctx, emailVerification, email, settings) } // PasswordResetLink generates the out-of-band email action link for password reset flows for the specified email // address. 
func (c *baseClient) PasswordResetLink(ctx context.Context, email string) (string, error) { return c.PasswordResetLinkWithSettings(ctx, email, nil) } // PasswordResetLinkWithSettings generates the out-of-band email action link for password reset flows for the // specified email address, using the action code settings provided. func (c *baseClient) PasswordResetLinkWithSettings( ctx context.Context, email string, settings *ActionCodeSettings) (string, error) { return c.generateEmailActionLink(ctx, passwordReset, email, settings) } // EmailSignInLink generates the out-of-band email action link for email link sign-in flows, using the action // code settings provided. func (c *baseClient) EmailSignInLink( ctx context.Context, email string, settings *ActionCodeSettings) (string, error) { return c.generateEmailActionLink(ctx, emailLinkSignIn, email, settings) } func (c *baseClient) generateEmailActionLink( ctx context.Context, linkType linkType, email string, settings *ActionCodeSettings) (string, error) { if email == "" { return "", errors.New("email must not be empty") } if linkType == emailLinkSignIn && settings == nil { return "", errors.New("ActionCodeSettings must not be nil when generating sign-in links") } payload := map[string]interface{}{ "requestType": linkType, "email": email, "returnOobLink": true, } if settings != nil { settingsMap, err := settings.toMap() if err != nil { return "", err } for k, v := range settingsMap { payload[k] = v } } var result struct { OOBLink string `json:"oobLink"` } _, err := c.post(ctx, "/accounts:sendOobCode", payload, &result) return result.OOBLink, err } golang-google-firebase-go-4.18.0/auth/email_action_links_test.go000066400000000000000000000230651505612111400246630ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "encoding/json" "fmt" "net/http" "reflect" "testing" "firebase.google.com/go/v4/errorutils" ) const ( testActionLink = "https://test.link" testActionLinkFormat = `{"oobLink": %q}` testEmail = "user@domain.com" ) var testActionLinkResponse = []byte(fmt.Sprintf(testActionLinkFormat, testActionLink)) var testActionCodeSettings = &ActionCodeSettings{ URL: "https://example.dynamic.link", HandleCodeInApp: true, LinkDomain: "hosted.page.link", DynamicLinkDomain: "custom.page.link", IOSBundleID: "com.example.ios", AndroidPackageName: "com.example.android", AndroidInstallApp: true, AndroidMinimumVersion: "6", } var testActionCodeSettingsMap = map[string]interface{}{ "continueUrl": "https://example.dynamic.link", "canHandleCodeInApp": true, "linkDomain": "hosted.page.link", "dynamicLinkDomain": "custom.page.link", "iOSBundleId": "com.example.ios", "androidPackageName": "com.example.android", "androidInstallApp": true, "androidMinimumVersion": "6", } var invalidActionCodeSettings = []struct { name string settings *ActionCodeSettings want string }{ { "no-url", &ActionCodeSettings{}, "URL must not be empty", }, { "malformed-url", &ActionCodeSettings{ URL: "not a url", }, `malformed url string: "not a url"`, }, { "no-android-package-1", &ActionCodeSettings{ URL: "https://example.dynamic.link", AndroidInstallApp: true, }, "Android package name is required when specifying other Android settings", }, { "no-android-package-2", &ActionCodeSettings{ URL: "https://example.dynamic.link", AndroidMinimumVersion: "6", }, "Android package name is 
required when specifying other Android settings", }, } func TestEmailVerificationLink(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() link, err := s.Client.EmailVerificationLink(context.Background(), testEmail) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("EmailVerificationLink() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "VERIFY_EMAIL", "email": testEmail, "returnOobLink": true, } if err := checkActionLinkRequest(want, s); err != nil { t.Fatalf("EmailVerificationLink() %v", err) } } func TestEmailVerificationLinkWithSettings(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() link, err := s.Client.EmailVerificationLinkWithSettings(context.Background(), testEmail, testActionCodeSettings) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("EmailVerificationLinkWithSettings() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "VERIFY_EMAIL", "email": testEmail, "returnOobLink": true, } for k, v := range testActionCodeSettingsMap { want[k] = v } if err := checkActionLinkRequest(want, s); err != nil { t.Fatalf("EmailVerificationLinkWithSettings() %v", err) } } func TestPasswordResetLink(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() link, err := s.Client.PasswordResetLink(context.Background(), testEmail) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("PasswordResetLink() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "PASSWORD_RESET", "email": testEmail, "returnOobLink": true, } if err := checkActionLinkRequest(want, s); err != nil { t.Fatalf("PasswordResetLink() %v", err) } } func TestPasswordResetLinkWithSettings(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() link, err := s.Client.PasswordResetLinkWithSettings(context.Background(), testEmail, testActionCodeSettings) if err != nil 
{ t.Fatal(err) } if link != testActionLink { t.Errorf("PasswordResetLinkWithSettings() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "PASSWORD_RESET", "email": testEmail, "returnOobLink": true, } for k, v := range testActionCodeSettingsMap { want[k] = v } if err := checkActionLinkRequest(want, s); err != nil { t.Fatalf("PasswordResetLinkWithSettings() %v", err) } } func TestPasswordResetLinkWithSettingsNonExistingUser(t *testing.T) { resp := `{ "error": { "message": "EMAIL_NOT_FOUND" } }` s := echoServer([]byte(resp), t) defer s.Close() s.Status = http.StatusBadRequest link, err := s.Client.PasswordResetLinkWithSettings(context.Background(), testEmail, testActionCodeSettings) if link != "" || err == nil { t.Errorf("PasswordResetLinkWithSettings() = (%q, %v); want = (%q, error)", link, err, "") } want := "no user record found for the given email" if err.Error() != want || !IsEmailNotFound(err) || !errorutils.IsNotFound(err) { t.Errorf("PasswordResetLinkWithSettings() error = %v; want = %q", err, want) } } func TestEmailSignInLink(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() link, err := s.Client.EmailSignInLink(context.Background(), testEmail, testActionCodeSettings) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("EmailSignInLink() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "EMAIL_SIGNIN", "email": testEmail, "returnOobLink": true, } for k, v := range testActionCodeSettingsMap { want[k] = v } if err := checkActionLinkRequest(want, s); err != nil { t.Fatalf("EmailSignInLink() %v", err) } } func TestEmailActionLinkNoEmail(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } if _, err := client.EmailVerificationLink(context.Background(), ""); err == nil { t.Errorf("EmailVerificationLink('') = nil; want error") } if _, err := client.PasswordResetLink(context.Background(), ""); err == nil { t.Errorf("PasswordResetLink('') 
= nil; want error") } if _, err := client.EmailSignInLink(context.Background(), "", testActionCodeSettings); err == nil { t.Errorf("EmailSignInLink('') = nil; want error") } } func TestEmailVerificationLinkInvalidSettings(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } for _, tc := range invalidActionCodeSettings { _, err := client.EmailVerificationLinkWithSettings(context.Background(), testEmail, tc.settings) if err == nil || err.Error() != tc.want { t.Errorf("EmailVerificationLinkWithSettings(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestPasswordResetLinkInvalidSettings(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } for _, tc := range invalidActionCodeSettings { _, err := client.PasswordResetLinkWithSettings(context.Background(), testEmail, tc.settings) if err == nil || err.Error() != tc.want { t.Errorf("PasswordResetLinkWithSettings(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestEmailSignInLinkInvalidSettings(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } for _, tc := range invalidActionCodeSettings { _, err := client.EmailSignInLink(context.Background(), testEmail, tc.settings) if err == nil || err.Error() != tc.want { t.Errorf("EmailSignInLink(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestEmailSignInLinkNoSettings(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } _, err := client.EmailSignInLink(context.Background(), testEmail, nil) if err == nil { t.Errorf("EmailSignInLink(nil) = %v; want = error", err) } } func TestEmailVerificationLinkError(t *testing.T) { cases := map[string]func(error) bool{ "UNAUTHORIZED_DOMAIN": IsUnauthorizedContinueURI, "INVALID_DYNAMIC_LINK_DOMAIN": IsInvalidDynamicLinkDomain, "INVALID_HOSTING_LINK_DOMAIN": IsInvalidHostingLinkDomain, } s := echoServer(testActionLinkResponse, t) defer s.Close() s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError for code, check := range cases { resp := 
fmt.Sprintf(`{"error": {"message": %q}}`, code) s.Resp = []byte(resp) _, err := s.Client.EmailVerificationLink(context.Background(), testEmail) if err == nil || !check(err) { t.Errorf("EmailVerificationLink(%q) = %v; want = %q", code, err, serverError[code]) } } } func checkActionLinkRequest(want map[string]interface{}, s *mockAuthServer) error { wantURL := "/projects/mock-project-id/accounts:sendOobCode" return checkActionLinkRequestWithURL(want, wantURL, s) } func checkActionLinkRequestWithURL(want map[string]interface{}, wantURL string, s *mockAuthServer) error { req := s.Req[0] if req.Method != http.MethodPost { return fmt.Errorf("Method = %q; want = %q", req.Method, http.MethodPatch) } if req.URL.Path != wantURL { return fmt.Errorf("URL = %q; want = %q", req.URL.Path, wantURL) } var got map[string]interface{} if err := json.Unmarshal(s.Rbody, &got); err != nil { return err } if !reflect.DeepEqual(got, want) { return fmt.Errorf("Body = %#v; want = %#v", got, want) } return nil } golang-google-firebase-go-4.18.0/auth/export_users.go000066400000000000000000000064701505612111400225430ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "fmt" "net/http" "net/url" "strconv" "firebase.google.com/go/v4/internal" "google.golang.org/api/iterator" ) const maxReturnedResults = 1000 // Users returns an iterator over Users. 
// // If nextPageToken is empty, the iterator will start at the beginning. // If the nextPageToken is not empty, the iterator starts after the token. func (c *baseClient) Users(ctx context.Context, nextPageToken string) *UserIterator { it := &UserIterator{ ctx: ctx, client: c, } it.pageInfo, it.nextFunc = iterator.NewPageInfo( it.fetch, func() int { return len(it.users) }, func() interface{} { b := it.users; it.users = nil; return b }) it.pageInfo.MaxSize = maxReturnedResults it.pageInfo.Token = nextPageToken return it } // UserIterator is an iterator over Users. // // Also see: https://github.com/GoogleCloudPlatform/google-cloud-go/wiki/Iterator-Guidelines type UserIterator struct { client *baseClient ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo users []*ExportedUserRecord } // PageInfo supports pagination. See the google.golang.org/api/iterator package for details. // Page size can be determined by the NewPager(...) function described there. func (it *UserIterator) PageInfo() *iterator.PageInfo { return it.pageInfo } // Next returns the next result. Its second return value is [iterator.Done] if // there are no more results. Once Next returns [iterator.Done], all subsequent // calls will return [iterator.Done]. 
func (it *UserIterator) Next() (*ExportedUserRecord, error) { if err := it.nextFunc(); err != nil { return nil, err } user := it.users[0] it.users = it.users[1:] return user, nil } func (it *UserIterator) fetch(pageSize int, pageToken string) (string, error) { query := make(url.Values) query.Set("maxResults", strconv.Itoa(pageSize)) if pageToken != "" { query.Set("nextPageToken", pageToken) } url, err := it.client.makeUserMgtURL(fmt.Sprintf("/accounts:batchGet?%s", query.Encode())) if err != nil { return "", err } req := &internal.Request{ Method: http.MethodGet, URL: url, } var parsed struct { Users []userQueryResponse `json:"users"` NextPageToken string `json:"nextPageToken"` } _, err = it.client.httpClient.DoAndUnmarshal(it.ctx, req, &parsed) if err != nil { return "", err } for _, u := range parsed.Users { eu, err := u.makeExportedUserRecord() if err != nil { return "", err } it.users = append(it.users, eu) } it.pageInfo.Token = parsed.NextPageToken return parsed.NextPageToken, nil } // ExportedUserRecord is the returned user value used when listing all the users. type ExportedUserRecord struct { *UserRecord PasswordHash string PasswordSalt string } golang-google-firebase-go-4.18.0/auth/hash/000077500000000000000000000000001505612111400203665ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/auth/hash/hash.go000066400000000000000000000216421505612111400216450ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

// Package hash contains a collection of password hash algorithms that can be used with the
// auth.ImportUsers() API. Refer to https://firebase.google.com/docs/auth/admin/import-users for
// more details about supported hash algorithms.
package hash

import (
	"encoding/base64"
	"errors"
	"fmt"

	"firebase.google.com/go/v4/internal"
)

// InputOrderType specifies the order in which users' passwords/salts are hashed
type InputOrderType int

// Available InputOrderType values
const (
	// InputOrderUnspecified is the zero value; no "passwordHashOrder" entry is
	// written to the hash configuration for it.
	InputOrderUnspecified InputOrderType = iota
	// InputOrderSaltFirst is written to the configuration as "SALT_AND_PASSWORD".
	InputOrderSaltFirst
	// InputOrderPasswordFirst is written to the configuration as "PASSWORD_AND_SALT".
	InputOrderPasswordFirst
)

// Bcrypt represents the BCRYPT hash algorithm.
//
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_bcrypt_hashed_passwords
// for more details.
type Bcrypt struct{}

// Config returns the hash configuration for BCRYPT. The algorithm takes no
// parameters, so the returned error is always nil.
func (b Bcrypt) Config() (internal.HashConfig, error) {
	return internal.HashConfig{"hashAlgorithm": "BCRYPT"}, nil
}

// StandardScrypt represents the standard scrypt hash algorithm.
//
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_standard_scrypt_hashed_passwords
// for more details.
type StandardScrypt struct {
	BlockSize        int // sent as "blockSize"
	DerivedKeyLength int // sent as "dkLen"
	MemoryCost       int // sent as "cpuMemCost"
	Parallelization  int // sent as "parallelization"
}

// Config returns the hash configuration for STANDARD_SCRYPT.
//
// The fields are copied into the configuration as they are; this method does
// not range-check them and always returns a nil error.
func (s StandardScrypt) Config() (internal.HashConfig, error) {
	return internal.HashConfig{
		"hashAlgorithm":   "STANDARD_SCRYPT",
		"dkLen":           s.DerivedKeyLength,
		"blockSize":       s.BlockSize,
		"parallelization": s.Parallelization,
		"cpuMemCost":      s.MemoryCost,
	}, nil
}

// Scrypt represents the scrypt hash algorithm.
//
// This is the modified scrypt used by Firebase Auth (https://github.com/firebase/scrypt).
// Rounds must be between 1 and 8, and the MemoryCost must be between 1 and 14. Key is required.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_firebase_scrypt_hashed_passwords // for more details. type Scrypt struct { Key []byte SaltSeparator []byte Rounds int MemoryCost int } // Config returns the validated hash configuration. func (s Scrypt) Config() (internal.HashConfig, error) { if len(s.Key) == 0 { return nil, errors.New("signer key not specified") } if s.Rounds < 1 || s.Rounds > 8 { return nil, errors.New("rounds must be between 1 and 8") } if s.MemoryCost < 1 || s.MemoryCost > 14 { return nil, errors.New("memory cost must be between 1 and 14") } return internal.HashConfig{ "hashAlgorithm": "SCRYPT", "signerKey": base64.RawURLEncoding.EncodeToString(s.Key), "saltSeparator": base64.RawURLEncoding.EncodeToString(s.SaltSeparator), "rounds": s.Rounds, "memoryCost": s.MemoryCost, }, nil } // HMACMD5 represents the HMAC SHA512 hash algorithm. // // Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_hmac_hashed_passwords // for more details. Key is required. type HMACMD5 struct { Key []byte InputOrder InputOrderType } // Config returns the validated hash configuration. func (h HMACMD5) Config() (internal.HashConfig, error) { return hmacConfig("HMAC_MD5", h.Key, h.InputOrder) } // HMACSHA1 represents the HMAC SHA512 hash algorithm. // // Key is required. // Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_hmac_hashed_passwords // for more details. type HMACSHA1 struct { Key []byte InputOrder InputOrderType } // Config returns the validated hash configuration. func (h HMACSHA1) Config() (internal.HashConfig, error) { return hmacConfig("HMAC_SHA1", h.Key, h.InputOrder) } // HMACSHA256 represents the HMAC SHA512 hash algorithm. // // Key is required. // Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_hmac_hashed_passwords // for more details. 
type HMACSHA256 struct {
	Key        []byte         // signer key; Config fails when it is empty
	InputOrder InputOrderType // optional salt/password ordering
}

// Config returns the validated hash configuration.
func (h HMACSHA256) Config() (internal.HashConfig, error) {
	return hmacConfig("HMAC_SHA256", h.Key, h.InputOrder)
}

// HMACSHA512 represents the HMAC SHA512 hash algorithm.
//
// Key is required.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_hmac_hashed_passwords
// for more details.
type HMACSHA512 struct {
	Key        []byte         // signer key; Config fails when it is empty
	InputOrder InputOrderType // optional salt/password ordering
}

// Config returns the validated hash configuration.
func (h HMACSHA512) Config() (internal.HashConfig, error) {
	return hmacConfig("HMAC_SHA512", h.Key, h.InputOrder)
}

// MD5 represents the MD5 hash algorithm.
//
// Rounds must be between 0 and 8192.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_md5_sha_and_pbkdf_hashed_passwords
// for more details.
type MD5 struct {
	Rounds     int            // number of hash rounds
	InputOrder InputOrderType // optional salt/password ordering
}

// Config returns the validated hash configuration.
func (h MD5) Config() (internal.HashConfig, error) {
	return basicConfig("MD5", h.Rounds, h.InputOrder)
}

// PBKDF2SHA256 represents the PBKDF2SHA256 hash algorithm.
//
// Rounds must be between 0 and 120000.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_md5_sha_and_pbkdf_hashed_passwords
// for more details.
type PBKDF2SHA256 struct {
	Rounds int // number of hash rounds
}

// Config returns the validated hash configuration. No input order is ever
// specified for this algorithm.
func (h PBKDF2SHA256) Config() (internal.HashConfig, error) {
	return basicConfig("PBKDF2_SHA256", h.Rounds, InputOrderUnspecified)
}

// PBKDFSHA1 represents the PBKDFSHA1 hash algorithm.
//
// Rounds must be between 0 and 120000.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_md5_sha_and_pbkdf_hashed_passwords
// for more details.
type PBKDFSHA1 struct {
	Rounds int // number of hash rounds
}

// Config returns the validated hash configuration.
func (h PBKDFSHA1) Config() (internal.HashConfig, error) {
	// No input order is ever specified for this algorithm.
	return basicConfig("PBKDF_SHA1", h.Rounds, InputOrderUnspecified)
}

// SHA1 represents the SHA1 hash algorithm.
//
// Rounds must be between 1 and 8192.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_md5_sha_and_pbkdf_hashed_passwords
// for more details.
type SHA1 struct {
	Rounds     int            // number of hash rounds
	InputOrder InputOrderType // optional salt/password ordering
}

// Config returns the validated hash configuration.
func (h SHA1) Config() (internal.HashConfig, error) {
	return basicConfig("SHA1", h.Rounds, h.InputOrder)
}

// SHA256 represents the SHA256 hash algorithm.
//
// Rounds must be between 1 and 8192.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_md5_sha_and_pbkdf_hashed_passwords
// for more details.
type SHA256 struct {
	Rounds     int            // number of hash rounds
	InputOrder InputOrderType // optional salt/password ordering
}

// Config returns the validated hash configuration.
func (h SHA256) Config() (internal.HashConfig, error) {
	return basicConfig("SHA256", h.Rounds, h.InputOrder)
}

// SHA512 represents the SHA512 hash algorithm.
//
// Rounds must be between 1 and 8192.
// Refer to https://firebase.google.com/docs/auth/admin/import-users#import_users_with_md5_sha_and_pbkdf_hashed_passwords
// for more details.
type SHA512 struct {
	Rounds     int            // number of hash rounds
	InputOrder InputOrderType // optional salt/password ordering
}

// Config returns the validated hash configuration.
func (h SHA512) Config() (internal.HashConfig, error) { return basicConfig("SHA512", h.Rounds, h.InputOrder) } func hmacConfig(name string, key []byte, order InputOrderType) (internal.HashConfig, error) { if len(key) == 0 { return nil, errors.New("signer key not specified") } conf := internal.HashConfig{ "hashAlgorithm": name, "signerKey": base64.RawURLEncoding.EncodeToString(key), } if order == InputOrderSaltFirst { conf["passwordHashOrder"] = "SALT_AND_PASSWORD" } else if order == InputOrderPasswordFirst { conf["passwordHashOrder"] = "PASSWORD_AND_SALT" } return conf, nil } func basicConfig(name string, rounds int, order InputOrderType) (internal.HashConfig, error) { minRounds := 0 maxRounds := 120000 switch name { case "MD5": maxRounds = 8192 case "SHA1", "SHA256", "SHA512": minRounds = 1 maxRounds = 8192 } if rounds < minRounds || maxRounds < rounds { return nil, fmt.Errorf("rounds must be between %d and %d", minRounds, maxRounds) } conf := internal.HashConfig{ "hashAlgorithm": name, "rounds": rounds, } if order == InputOrderSaltFirst { conf["passwordHashOrder"] = "SALT_AND_PASSWORD" } else if order == InputOrderPasswordFirst { conf["passwordHashOrder"] = "PASSWORD_AND_SALT" } return conf, nil } golang-google-firebase-go-4.18.0/auth/hash/hash_test.go000066400000000000000000000264221505612111400227050ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package hash import ( "encoding/base64" "reflect" "testing" "firebase.google.com/go/v4/auth" "firebase.google.com/go/v4/internal" ) var ( signerKey = []byte("key") saltSeparator = []byte("sep") ) var validHashes = []struct { alg auth.UserImportHash want internal.HashConfig }{ { alg: Bcrypt{}, want: internal.HashConfig{"hashAlgorithm": "BCRYPT"}, }, { alg: StandardScrypt{ BlockSize: 1, DerivedKeyLength: 2, Parallelization: 3, MemoryCost: 4, }, want: internal.HashConfig{ "hashAlgorithm": "STANDARD_SCRYPT", "blockSize": 1, "dkLen": 2, "parallelization": 3, "cpuMemCost": 4, }, }, { alg: Scrypt{ Key: signerKey, SaltSeparator: saltSeparator, Rounds: 8, MemoryCost: 14, }, want: internal.HashConfig{ "hashAlgorithm": "SCRYPT", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "saltSeparator": base64.RawURLEncoding.EncodeToString(saltSeparator), "rounds": 8, "memoryCost": 14, }, }, { alg: HMACMD5{Key: signerKey}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_MD5", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACSHA1{Key: signerKey}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA1", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACSHA256{Key: signerKey}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA256", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACSHA512{Key: signerKey}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA512", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: MD5{Rounds: 0}, want: internal.HashConfig{ "hashAlgorithm": "MD5", "rounds": 0, }, }, { alg: MD5{Rounds: 8192}, want: internal.HashConfig{ "hashAlgorithm": "MD5", "rounds": 8192, }, }, { alg: SHA1{Rounds: 1}, want: internal.HashConfig{ "hashAlgorithm": "SHA1", "rounds": 1, }, }, { alg: SHA1{Rounds: 8192}, want: internal.HashConfig{ "hashAlgorithm": "SHA1", "rounds": 8192, }, }, { alg: SHA256{Rounds: 1}, want: internal.HashConfig{ "hashAlgorithm": 
"SHA256", "rounds": 1, }, }, { alg: SHA256{Rounds: 8192}, want: internal.HashConfig{ "hashAlgorithm": "SHA256", "rounds": 8192, }, }, { alg: SHA512{Rounds: 1}, want: internal.HashConfig{ "hashAlgorithm": "SHA512", "rounds": 1, }, }, { alg: SHA512{Rounds: 8192}, want: internal.HashConfig{ "hashAlgorithm": "SHA512", "rounds": 8192, }, }, { alg: PBKDFSHA1{Rounds: 0}, want: internal.HashConfig{ "hashAlgorithm": "PBKDF_SHA1", "rounds": 0, }, }, { alg: PBKDFSHA1{Rounds: 120000}, want: internal.HashConfig{ "hashAlgorithm": "PBKDF_SHA1", "rounds": 120000, }, }, { alg: PBKDF2SHA256{Rounds: 0}, want: internal.HashConfig{ "hashAlgorithm": "PBKDF2_SHA256", "rounds": 0, }, }, { alg: PBKDF2SHA256{Rounds: 120000}, want: internal.HashConfig{ "hashAlgorithm": "PBKDF2_SHA256", "rounds": 120000, }, }, } var invalidHashes = []struct { name string alg auth.UserImportHash }{ { name: "SCRYPT: no signer key", alg: Scrypt{ SaltSeparator: saltSeparator, Rounds: 8, MemoryCost: 14, }, }, { name: "SCRYPT: low rounds", alg: Scrypt{ Key: signerKey, SaltSeparator: saltSeparator, MemoryCost: 14, }, }, { name: "SCRYPT: high rounds", alg: Scrypt{ Key: signerKey, SaltSeparator: saltSeparator, Rounds: 9, MemoryCost: 14, }, }, { name: "SCRYPT: low memory cost", alg: Scrypt{ Key: signerKey, SaltSeparator: saltSeparator, Rounds: 8, }, }, { name: "SCRYPT: high memory cost", alg: Scrypt{ Key: signerKey, SaltSeparator: saltSeparator, Rounds: 8, MemoryCost: 15, }, }, { name: "HMAC_MD5: no signer key", alg: HMACMD5{}, }, { name: "HMAC_SHA1: no signer key", alg: HMACSHA1{}, }, { name: "HMAC_SHA256: no signer key", alg: HMACSHA256{}, }, { name: "HMAC_SHA512: no signer key", alg: HMACSHA512{}, }, { name: "MD5: rounds too low", alg: MD5{Rounds: -1}, }, { name: "SHA1: rounds too low", alg: SHA1{Rounds: 0}, }, { name: "SHA256: rounds too low", alg: SHA256{Rounds: 0}, }, { name: "SHA512: rounds too low", alg: SHA512{Rounds: 0}, }, { name: "PBKDFSHA1: rounds too low", alg: PBKDFSHA1{Rounds: -1}, }, { name: 
"PBKDF2SHA256: rounds too low", alg: PBKDF2SHA256{Rounds: -1}, }, { name: "MD5: rounds too high", alg: MD5{Rounds: 8193}, }, { name: "SHA1: rounds too high", alg: SHA1{Rounds: 8193}, }, { name: "SHA256: rounds too high", alg: SHA256{Rounds: 8193}, }, { name: "SHA512: rounds too high", alg: SHA512{Rounds: 8193}, }, { name: "PBKDFSHA1: rounds too high", alg: PBKDFSHA1{Rounds: 120001}, }, { name: "PBKDF2SHA256: rounds too high", alg: PBKDF2SHA256{Rounds: 120001}, }, } func TestValidHash(t *testing.T) { for idx, tc := range validHashes { got, err := tc.alg.Config() if err != nil { t.Errorf("[%d] Config() = %v", idx, err) } else if !reflect.DeepEqual(got, tc.want) { t.Errorf("[%d] Config() = %#v; want = %#v", idx, got, tc.want) } } } func TestInvalidHash(t *testing.T) { for _, tc := range invalidHashes { got, err := tc.alg.Config() if got != nil || err == nil { t.Errorf("%s; Config() = (%v, %v); want = (nil, error)", tc.name, got, err) } } } var validHashesOrder = []struct { alg auth.UserImportHash want internal.HashConfig }{ { alg: HMACMD5{Key: signerKey, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_MD5", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACMD5{Key: signerKey, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_MD5", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "SALT_AND_PASSWORD", }, }, { alg: HMACMD5{Key: signerKey, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_MD5", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: HMACSHA1{Key: signerKey, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA1", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACSHA1{Key: signerKey, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": 
"HMAC_SHA1", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "SALT_AND_PASSWORD", }, }, { alg: HMACSHA1{Key: signerKey, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA1", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: HMACSHA256{Key: signerKey, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA256", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACSHA256{Key: signerKey, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA256", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "SALT_AND_PASSWORD", }, }, { alg: HMACSHA256{Key: signerKey, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA256", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: HMACSHA512{Key: signerKey, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA512", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), }, }, { alg: HMACSHA512{Key: signerKey, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA512", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "SALT_AND_PASSWORD", }, }, { alg: HMACSHA512{Key: signerKey, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "HMAC_SHA512", "signerKey": base64.RawURLEncoding.EncodeToString(signerKey), "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: SHA1{Rounds: 1, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "SHA1", "rounds": 1, }, }, { alg: SHA1{Rounds: 1, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": "SHA1", "rounds": 1, "passwordHashOrder": 
"SALT_AND_PASSWORD", }, }, { alg: SHA1{Rounds: 1, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "SHA1", "rounds": 1, "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: SHA256{Rounds: 1, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "SHA256", "rounds": 1, }, }, { alg: SHA256{Rounds: 1, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": "SHA256", "rounds": 1, "passwordHashOrder": "SALT_AND_PASSWORD", }, }, { alg: SHA256{Rounds: 1, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "SHA256", "rounds": 1, "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: SHA512{Rounds: 1, InputOrder: InputOrderUnspecified}, want: internal.HashConfig{ "hashAlgorithm": "SHA512", "rounds": 1, }, }, { alg: SHA512{Rounds: 1, InputOrder: InputOrderSaltFirst}, want: internal.HashConfig{ "hashAlgorithm": "SHA512", "rounds": 1, "passwordHashOrder": "SALT_AND_PASSWORD", }, }, { alg: SHA512{Rounds: 1, InputOrder: InputOrderPasswordFirst}, want: internal.HashConfig{ "hashAlgorithm": "SHA512", "rounds": 1, "passwordHashOrder": "PASSWORD_AND_SALT", }, }, { alg: SHA512{Rounds: 8192}, want: internal.HashConfig{ "hashAlgorithm": "SHA512", "rounds": 8192, }, }, } func TestHashOrder(t *testing.T) { for idx, tc := range validHashesOrder { got, err := tc.alg.Config() if err != nil { t.Errorf("[%d] Config() = %v", idx, err) } else if !reflect.DeepEqual(got, tc.want) { t.Errorf("[%d] Config() = %#v; want = %#v", idx, got, tc.want) } } } golang-google-firebase-go-4.18.0/auth/import_users.go000066400000000000000000000165351505612111400225370ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"

	"firebase.google.com/go/v4/internal"
)

// maxImportUsers is the largest number of users accepted by a single
// ImportUsers() call.
const maxImportUsers = 1000

// UserImportOption is an option for the ImportUsers() function.
type UserImportOption interface {
	// applyTo adds the option's settings to the request payload, or returns an
	// error when the option is invalid.
	applyTo(req map[string]interface{}) error
}

// UserImportResult represents the result of an ImportUsers() call.
type UserImportResult struct {
	SuccessCount int          // number of submitted users minus the reported errors
	FailureCount int          // number of errors reported by the server
	Errors       []*ErrorInfo // one entry per reported error
}

// ErrorInfo represents an error encountered while importing a single user account.
//
// The Index field corresponds to the index of the failed user in the users array that was passed
// to ImportUsers().
type ErrorInfo struct {
	Index  int
	Reason string
}

// ImportUsers imports an array of users to Firebase Auth.
//
// No more than 1000 users can be imported in a single call. If at least one user specifies a
// password, a UserImportHash must be specified as an option.
func (c *baseClient) ImportUsers( ctx context.Context, users []*UserToImport, opts ...UserImportOption) (*UserImportResult, error) { if len(users) == 0 { return nil, errors.New("users list must not be empty") } if len(users) > maxImportUsers { return nil, fmt.Errorf("users list must not contain more than %d elements", maxImportUsers) } var validatedUsers []map[string]interface{} hashRequired := false for _, u := range users { vu, err := u.validatedUserInfo() if err != nil { return nil, err } if pw, ok := vu["passwordHash"]; ok && pw != "" { hashRequired = true } validatedUsers = append(validatedUsers, vu) } req := map[string]interface{}{ "users": validatedUsers, } for _, opt := range opts { if err := opt.applyTo(req); err != nil { return nil, err } } if hashRequired { if algo, ok := req["hashAlgorithm"]; !ok || algo == "" { return nil, errors.New("hash algorithm option is required to import users with passwords") } } var parsed struct { Error []struct { Index int `json:"index"` Message string `json:"message"` } `json:"error,omitempty"` } _, err := c.post(ctx, "/accounts:batchCreate", req, &parsed) if err != nil { return nil, err } result := &UserImportResult{ SuccessCount: len(users) - len(parsed.Error), FailureCount: len(parsed.Error), } for _, e := range parsed.Error { result.Errors = append(result.Errors, &ErrorInfo{ Index: int(e.Index), Reason: e.Message, }) } return result, nil } // UserToImport represents a user account that can be bulk imported into Firebase Auth. type UserToImport struct { params map[string]interface{} } // UID setter. This field is required. func (u *UserToImport) UID(uid string) *UserToImport { return u.set("localId", uid) } // Email setter. func (u *UserToImport) Email(email string) *UserToImport { return u.set("email", email) } // DisplayName setter. func (u *UserToImport) DisplayName(displayName string) *UserToImport { return u.set("displayName", displayName) } // PhotoURL setter. 
func (u *UserToImport) PhotoURL(url string) *UserToImport {
	return u.set("photoUrl", url)
}

// PhoneNumber setter.
func (u *UserToImport) PhoneNumber(phoneNumber string) *UserToImport {
	return u.set("phoneNumber", phoneNumber)
}

// Metadata setter.
//
// Only non-zero timestamps are recorded. A nil metadata value is ignored, so
// the call leaves the user unchanged instead of panicking.
func (u *UserToImport) Metadata(metadata *UserMetadata) *UserToImport {
	if metadata == nil {
		return u
	}
	if metadata.CreationTimestamp != 0 {
		u.set("createdAt", metadata.CreationTimestamp)
	}
	if metadata.LastLogInTimestamp != 0 {
		u.set("lastLoginAt", metadata.LastLogInTimestamp)
	}
	return u
}

// CustomClaims setter.
func (u *UserToImport) CustomClaims(claims map[string]interface{}) *UserToImport {
	return u.set("customClaims", claims)
}

// Disabled setter.
func (u *UserToImport) Disabled(disabled bool) *UserToImport {
	return u.set("disabled", disabled)
}

// EmailVerified setter.
func (u *UserToImport) EmailVerified(emailVerified bool) *UserToImport {
	return u.set("emailVerified", emailVerified)
}

// PasswordHash setter. When set, a UserImportHash must be specified as an option to call
// ImportUsers().
//
// The hash bytes are stored as an unpadded URL-safe base64 string.
func (u *UserToImport) PasswordHash(password []byte) *UserToImport {
	return u.set("passwordHash", base64.RawURLEncoding.EncodeToString(password))
}

// PasswordSalt setter.
//
// The salt bytes are stored as an unpadded URL-safe base64 string.
func (u *UserToImport) PasswordSalt(salt []byte) *UserToImport {
	return u.set("salt", base64.RawURLEncoding.EncodeToString(salt))
}

// set stores value under key, creating the parameter map on first use, and
// returns u so that setters can be chained.
func (u *UserToImport) set(key string, value interface{}) *UserToImport {
	if u.params == nil {
		u.params = make(map[string]interface{})
	}
	u.params[key] = value
	return u
}

// UserProvider represents a user identity provider.
//
// One or more user providers can be specified for each user when importing in bulk.
// See UserToImport type.
type UserProvider struct {
	UID         string `json:"rawId"`
	ProviderID  string `json:"providerId"`
	Email       string `json:"email,omitempty"`
	DisplayName string `json:"displayName,omitempty"`
	PhotoURL    string `json:"photoUrl,omitempty"`
}

// ProviderData setter.
func (u *UserToImport) ProviderData(providers []*UserProvider) *UserToImport { return u.set("providerUserInfo", providers) } func (u *UserToImport) validatedUserInfo() (map[string]interface{}, error) { if len(u.params) == 0 { return nil, fmt.Errorf("no parameters are set on the user to import") } info := make(map[string]interface{}) for k, v := range u.params { info[k] = v } if err := validateUID(info["localId"].(string)); err != nil { return nil, err } if email, ok := info["email"]; ok { if err := validateEmail(email.(string)); err != nil { return nil, err } } if phone, ok := info["phoneNumber"]; ok { if err := validatePhone(phone.(string)); err != nil { return nil, err } } if claims, ok := info["customClaims"]; ok { claimsMap := claims.(map[string]interface{}) if len(claimsMap) > 0 { cc, err := marshalCustomClaims(claimsMap) if err != nil { return nil, err } info["customAttributes"] = cc } delete(info, "customClaims") } if providers, ok := info["providerUserInfo"]; ok { for _, p := range providers.([]*UserProvider) { if err := validateProviderUserInfo(p); err != nil { return nil, err } } } return info, nil } // WithHash returns a UserImportOption that specifies a hash configuration. func WithHash(hash UserImportHash) UserImportOption { return withHash{hash} } // UserImportHash represents a hash algorithm and the associated configuration that can be used to // hash user passwords. // // A UserImportHash must be specified in the form of a UserImportOption when importing users with // passwords. See ImportUsers() and WithHash() functions. 
type UserImportHash interface { Config() (internal.HashConfig, error) } type withHash struct { hash UserImportHash } func (w withHash) applyTo(req map[string]interface{}) error { conf, err := w.hash.Config() if err != nil { return err } for k, v := range conf { req[k] = v } return nil } golang-google-firebase-go-4.18.0/auth/multi_factor_config_mgt.go000066400000000000000000000063531505612111400246650ustar00rootroot00000000000000// Copyright 2023 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "fmt" ) // ProviderConfig represents a multi-factor auth provider configuration. // Currently, only TOTP is supported. type ProviderConfig struct { // The state of multi-factor configuration, whether it's enabled or disabled. State MultiFactorConfigState `json:"state"` // TOTPProviderConfig holds the TOTP (time-based one-time password) configuration that is used in second factor authentication. TOTPProviderConfig *TOTPProviderConfig `json:"totpProviderConfig,omitempty"` } // TOTPProviderConfig represents configuration settings for TOTP second factor auth. type TOTPProviderConfig struct { // The number of adjacent intervals used by TOTP. AdjacentIntervals int `json:"adjacentIntervals,omitempty"` } // MultiFactorConfigState represents whether the multi-factor configuration is enabled or disabled. type MultiFactorConfigState string // These constants represent the possible values for the MultiFactorConfigState type. 
const ( Enabled MultiFactorConfigState = "ENABLED" Disabled MultiFactorConfigState = "DISABLED" ) // MultiFactorConfig represents a multi-factor configuration for a tenant or project. // This can be used to define whether multi-factor authentication is enabled or disabled and the list of second factor challenges that are supported. type MultiFactorConfig struct { // A slice of pointers to ProviderConfig structs, each outlining the specific second factor authorization method. ProviderConfigs []*ProviderConfig `json:"providerConfigs,omitempty"` } func (mfa *MultiFactorConfig) validate() error { if mfa == nil { return nil } if len(mfa.ProviderConfigs) == 0 { return fmt.Errorf("\"ProviderConfigs\" must be a non-empty array of type \"ProviderConfig\"s") } for _, providerConfig := range mfa.ProviderConfigs { if providerConfig == nil { return fmt.Errorf("\"ProviderConfigs\" must be a non-empty array of type \"ProviderConfig\"s") } if err := providerConfig.validate(); err != nil { return err } } return nil } func (pvc *ProviderConfig) validate() error { if pvc.State == "" && pvc.TOTPProviderConfig == nil { return fmt.Errorf("\"ProviderConfig\" must be defined") } state := string(pvc.State) if state != string(Enabled) && state != string(Disabled) { return fmt.Errorf("\"ProviderConfig.State\" must be 'Enabled' or 'Disabled'") } return pvc.TOTPProviderConfig.validate() } func (tpvc *TOTPProviderConfig) validate() error { if tpvc == nil { return fmt.Errorf("\"TOTPProviderConfig\" must be defined") } if !(tpvc.AdjacentIntervals >= 1 && tpvc.AdjacentIntervals <= 10) { return fmt.Errorf("\"AdjacentIntervals\" must be an integer between 1 and 10 (inclusive)") } return nil } golang-google-firebase-go-4.18.0/auth/multi_factor_config_mgt_test.go000066400000000000000000000065311505612111400257220ustar00rootroot00000000000000// Copyright 2023 Google Inc. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "testing" ) func TestMultiFactorConfig(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{{ State: Disabled, TOTPProviderConfig: &TOTPProviderConfig{ AdjacentIntervals: 5, }, }}, } if err := mfa.validate(); err != nil { t.Errorf("MultiFactorConfig not valid") } } func TestMultiFactorConfigNoProviderConfigs(t *testing.T) { mfa := MultiFactorConfig{} want := "\"ProviderConfigs\" must be a non-empty array of type \"ProviderConfig\"s" if err := mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } func TestMultiFactorConfigNilProviderConfigs(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: nil, } want := "\"ProviderConfigs\" must be a non-empty array of type \"ProviderConfig\"s" if err := mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } func TestMultiFactorConfigNilProviderConfig(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{nil}, } want := "\"ProviderConfigs\" must be a non-empty array of type \"ProviderConfig\"s" if err := mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } func TestMultiFactorConfigUndefinedProviderConfig(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{{}}, } want := "\"ProviderConfig\" must be defined" if err 
:= mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } func TestMultiFactorConfigInvalidProviderConfigState(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{{ State: "invalid", }}, } want := "\"ProviderConfig.State\" must be 'Enabled' or 'Disabled'" if err := mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } func TestMultiFactorConfigNilTOTPProviderConfig(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{{ State: Disabled, TOTPProviderConfig: nil, }}, } want := "\"TOTPProviderConfig\" must be defined" if err := mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } func TestMultiFactorConfigInvalidAdjacentIntervals(t *testing.T) { mfa := MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{{ State: Disabled, TOTPProviderConfig: &TOTPProviderConfig{ AdjacentIntervals: 11, }, }}, } want := "\"AdjacentIntervals\" must be an integer between 1 and 10 (inclusive)" if err := mfa.validate(); err.Error() != want { t.Errorf("MultiFactorConfig.validate(nil) = %v, want = %q", err, want) } } golang-google-firebase-go-4.18.0/auth/project_config_mgt.go000066400000000000000000000061301505612111400236340ustar00rootroot00000000000000// Copyright 2023 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "errors" "fmt" "net/http" "strings" "firebase.google.com/go/v4/internal" ) // ProjectConfig represents the properties to update on the provided project config. type ProjectConfig struct { MultiFactorConfig *MultiFactorConfig `json:"mfa,omitEmpty"` } func (base *baseClient) GetProjectConfig(ctx context.Context) (*ProjectConfig, error) { req := &internal.Request{ Method: http.MethodGet, URL: "/config", } var result ProjectConfig if _, err := base.makeRequest(ctx, req, &result); err != nil { return nil, err } return &result, nil } func (base *baseClient) UpdateProjectConfig(ctx context.Context, projectConfig *ProjectConfigToUpdate) (*ProjectConfig, error) { if projectConfig == nil { return nil, errors.New("project config must not be nil") } if err := projectConfig.validate(); err != nil { return nil, err } mask := projectConfig.params.UpdateMask() if len(mask) == 0 { return nil, errors.New("no parameters specified in the update request") } req := &internal.Request{ Method: http.MethodPatch, URL: "/config", Body: internal.NewJSONEntity(projectConfig.params), Opts: []internal.HTTPOption{ internal.WithQueryParam("updateMask", strings.Join(mask, ",")), }, } var result ProjectConfig if _, err := base.makeRequest(ctx, req, &result); err != nil { return nil, err } return &result, nil } // ProjectConfigToUpdate represents the options used to update the current project. 
type ProjectConfigToUpdate struct { params nestedMap } const ( multiFactorConfigProjectKey = "mfa" ) // MultiFactorConfig configures the project's multi-factor settings func (pc *ProjectConfigToUpdate) MultiFactorConfig(multiFactorConfig MultiFactorConfig) *ProjectConfigToUpdate { return pc.set(multiFactorConfigProjectKey, multiFactorConfig) } func (pc *ProjectConfigToUpdate) set(key string, value interface{}) *ProjectConfigToUpdate { pc.ensureParams().Set(key, value) return pc } func (pc *ProjectConfigToUpdate) ensureParams() nestedMap { if pc.params == nil { pc.params = make(nestedMap) } return pc.params } func (pc *ProjectConfigToUpdate) validate() error { req := make(map[string]interface{}) for k, v := range pc.params { req[k] = v } val, ok := req[multiFactorConfigProjectKey] if ok { multiFactorConfig, ok := val.(MultiFactorConfig) if !ok { return fmt.Errorf("invalid type for MultiFactorConfig: %s", req[multiFactorConfigProjectKey]) } if err := multiFactorConfig.validate(); err != nil { return err } } return nil } golang-google-firebase-go-4.18.0/auth/project_config_mgt_test.go000066400000000000000000000073221505612111400246770ustar00rootroot00000000000000// Copyright 2023 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "encoding/json" "fmt" "net/http" "reflect" "sort" "strings" "testing" "github.com/google/go-cmp/cmp" ) const projectConfigResponse = `{ "mfa": { "providerConfigs": [ { "state":"ENABLED", "totpProviderConfig":{ "adjacentIntervals":5 } } ] } }` var testProjectConfig = &ProjectConfig{ MultiFactorConfig: &MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{ { State: Enabled, TOTPProviderConfig: &TOTPProviderConfig{ AdjacentIntervals: 5, }, }, }, }, } func TestGetProjectConfig(t *testing.T) { s := echoServer([]byte(projectConfigResponse), t) defer s.Close() client := s.Client projectConfig, err := client.GetProjectConfig(context.Background()) if err != nil { t.Errorf("GetProjectConfig() = %v", err) } if !reflect.DeepEqual(projectConfig, testProjectConfig) { t.Errorf("GetProjectConfig() = %#v, want = %#v", projectConfig, testProjectConfig) } } func TestUpdateProjectConfig(t *testing.T) { s := echoServer([]byte(projectConfigResponse), t) defer s.Close() client := s.Client options := (&ProjectConfigToUpdate{}). 
MultiFactorConfig(*testProjectConfig.MultiFactorConfig) projectConfig, err := client.UpdateProjectConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(projectConfig, testProjectConfig) { t.Errorf("UpdateProjectConfig() = %#v; want = %#v", projectConfig, testProjectConfig) } wantBody := map[string]interface{}{ "mfa": map[string]interface{}{ "providerConfigs": []interface{}{ map[string]interface{}{ "state": "ENABLED", "totpProviderConfig": map[string]interface{}{ "adjacentIntervals": float64(5), }, }, }, }, } wantMask := []string{"mfa"} if err := checkUpdateProjectConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateProjectNilOptions(t *testing.T) { base := &baseClient{} want := "project config must not be nil" if _, err := base.UpdateProjectConfig(context.Background(), nil); err == nil || err.Error() != want { t.Errorf("UpdateProject(nil) = %v, want = %q", err, want) } } func checkUpdateProjectConfigRequest(s *mockAuthServer, wantBody interface{}, wantMask []string) error { req := s.Req[0] if req.Method != http.MethodPatch { return fmt.Errorf("UpdateProjectConfig() Method = %q; want = %q", req.Method, http.MethodPatch) } wantURL := "/projects/mock-project-id/config" if req.URL.Path != wantURL { return fmt.Errorf("UpdateProjectConfig() URL = %q; want = %q", req.URL.Path, wantURL) } queryParam := req.URL.Query().Get("updateMask") mask := strings.Split(queryParam, ",") sort.Strings(mask) if !reflect.DeepEqual(mask, wantMask) { return fmt.Errorf("UpdateProjectConfig() Query = %#v; want = %#v", mask, wantMask) } var body map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if diff := cmp.Diff(body, wantBody); diff != "" { fmt.Printf("UpdateProjectConfig() diff = %s", diff) } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("UpdateProjectConfig() Body = %#v; want = %#v", body, wantBody) } return nil } 
golang-google-firebase-go-4.18.0/auth/provider_config.go000066400000000000000000000743271505612111400231660ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "errors" "fmt" "net/http" "net/url" "strconv" "strings" "firebase.google.com/go/v4/internal" "google.golang.org/api/iterator" ) const ( maxConfigs = 100 idpEntityIDKey = "idpConfig.idpEntityId" ssoURLKey = "idpConfig.ssoUrl" signRequestKey = "idpConfig.signRequest" idpCertsKey = "idpConfig.idpCertificates" spEntityIDKey = "spConfig.spEntityId" callbackURIKey = "spConfig.callbackUri" clientIDKey = "clientId" clientSecretKey = "clientSecret" issuerKey = "issuer" displayNameKey = "displayName" enabledKey = "enabled" idTokenResponseTypeKey = "responseType.idToken" codeResponseTypeKey = "responseType.code" ) type nestedMap map[string]interface{} func (nm nestedMap) Get(key string) (interface{}, bool) { segments := strings.Split(key, ".") curr := map[string]interface{}(nm) for idx, segment := range segments { val, ok := curr[segment] if idx == len(segments)-1 || !ok { return val, ok } curr = val.(map[string]interface{}) } return nil, false } func (nm nestedMap) GetString(key string) (string, bool) { if val, ok := nm.Get(key); ok { return val.(string), true } return "", false } func (nm nestedMap) Set(key string, value interface{}) { segments := strings.Split(key, ".") curr := map[string]interface{}(nm) for idx, 
segment := range segments { if idx == len(segments)-1 { curr[segment] = value return } child, ok := curr[segment] if ok { curr = child.(map[string]interface{}) continue } newChild := make(map[string]interface{}) curr[segment] = newChild curr = newChild } } func (nm nestedMap) UpdateMask() []string { return buildMask(nm) } func buildMask(data map[string]interface{}) []string { var mask []string for k, v := range data { if child, ok := v.(map[string]interface{}); ok { childMask := buildMask(child) for _, item := range childMask { mask = append(mask, fmt.Sprintf("%s.%s", k, item)) } } else { mask = append(mask, k) } } return mask } // OIDCProviderConfig is the OIDC auth provider configuration. // See https://openid.net/specs/openid-connect-core-1_0-final.html. type OIDCProviderConfig struct { ID string DisplayName string Enabled bool ClientID string Issuer string ClientSecret string CodeResponseType bool IDTokenResponseType bool } // OIDCProviderConfigToCreate represents the options used to create a new OIDCProviderConfig. type OIDCProviderConfigToCreate struct { id string params nestedMap } // ID sets the provider ID of the new config. func (config *OIDCProviderConfigToCreate) ID(id string) *OIDCProviderConfigToCreate { config.id = id return config } // ClientID sets the client ID of the new config. func (config *OIDCProviderConfigToCreate) ClientID(clientID string) *OIDCProviderConfigToCreate { return config.set(clientIDKey, clientID) } // Issuer sets the issuer of the new config. func (config *OIDCProviderConfigToCreate) Issuer(issuer string) *OIDCProviderConfigToCreate { return config.set(issuerKey, issuer) } // DisplayName sets the DisplayName field of the new config. func (config *OIDCProviderConfigToCreate) DisplayName(name string) *OIDCProviderConfigToCreate { return config.set(displayNameKey, name) } // Enabled enables or disables the new config. 
func (config *OIDCProviderConfigToCreate) Enabled(enabled bool) *OIDCProviderConfigToCreate {
	config.set(enabledKey, enabled)
	return config
}

// ClientSecret records the client secret to use for the new provider.
// The code flow requires a client secret.
func (config *OIDCProviderConfigToCreate) ClientSecret(secret string) *OIDCProviderConfigToCreate {
	config.set(clientSecretKey, secret)
	return config
}

// IDTokenResponseType turns the ID token response flow on or off for the new provider.
// When no response type is specified at all, this flow is enabled by default.
// Enabling both the code and ID token response flows is currently not supported.
func (config *OIDCProviderConfigToCreate) IDTokenResponseType(enabled bool) *OIDCProviderConfigToCreate {
	config.set(idTokenResponseTypeKey, enabled)
	return config
}

// CodeResponseType sets whether to enable the code response flow for the new provider.
// By default, this is not enabled if no response type is specified.
// A client secret must be set for this response type.
// Having both the code and ID token response flows is currently not supported.
func (config *OIDCProviderConfigToCreate) CodeResponseType(enabled bool) *OIDCProviderConfigToCreate { return config.set(codeResponseTypeKey, enabled) } func (config *OIDCProviderConfigToCreate) set(key string, value interface{}) *OIDCProviderConfigToCreate { if config.params == nil { config.params = make(nestedMap) } config.params.Set(key, value) return config } func (config *OIDCProviderConfigToCreate) buildRequest() (nestedMap, string, error) { if err := validateOIDCConfigID(config.id); err != nil { return nil, "", err } if len(config.params) == 0 { return nil, "", errors.New("no parameters specified in the create request") } if val, ok := config.params.GetString(clientIDKey); !ok || val == "" { return nil, "", errors.New("ClientID must not be empty") } if val, ok := config.params.GetString(issuerKey); !ok || val == "" { return nil, "", errors.New("Issuer must not be empty") } else if _, err := url.ParseRequestURI(val); err != nil { return nil, "", fmt.Errorf("failed to parse Issuer: %v", err) } if val, ok := config.params.Get(codeResponseTypeKey); ok && val.(bool) { if val, ok := config.params.GetString(clientSecretKey); !ok || val == "" { return nil, "", errors.New("Client Secret must not be empty for Code Response Type") } if val, ok := config.params.Get(idTokenResponseTypeKey); ok && val.(bool) { return nil, "", errors.New("Only one response type may be chosen") } } else if ok && !val.(bool) { if val, ok := config.params.Get(idTokenResponseTypeKey); ok && !val.(bool) { return nil, "", errors.New("At least one response type must be returned") } } return config.params, config.id, nil } // OIDCProviderConfigToUpdate represents the options used to update an existing OIDCProviderConfig. type OIDCProviderConfigToUpdate struct { params nestedMap } // ClientID updates the client ID of the config. 
func (config *OIDCProviderConfigToUpdate) ClientID(clientID string) *OIDCProviderConfigToUpdate {
	config.set(clientIDKey, clientID)
	return config
}

// Issuer replaces the issuer of the config.
func (config *OIDCProviderConfigToUpdate) Issuer(issuer string) *OIDCProviderConfigToUpdate {
	config.set(issuerKey, issuer)
	return config
}

// DisplayName replaces the DisplayName field of the config.
// An empty name is stored as a nil value rather than as an empty string.
func (config *OIDCProviderConfigToUpdate) DisplayName(name string) *OIDCProviderConfigToUpdate {
	if name == "" {
		return config.set(displayNameKey, nil)
	}
	return config.set(displayNameKey, name)
}

// Enabled turns the config on or off.
func (config *OIDCProviderConfigToUpdate) Enabled(enabled bool) *OIDCProviderConfigToUpdate {
	config.set(enabledKey, enabled)
	return config
}

// ClientSecret records the client secret to use for the provider.
// The code flow requires a client secret.
func (config *OIDCProviderConfigToUpdate) ClientSecret(secret string) *OIDCProviderConfigToUpdate {
	config.set(clientSecretKey, secret)
	return config
}

// IDTokenResponseType turns the ID token response flow on or off for the provider.
// When no response type is specified at all, this flow is enabled by default.
// Enabling both the code and ID token response flows is currently not supported.
func (config *OIDCProviderConfigToUpdate) IDTokenResponseType(enabled bool) *OIDCProviderConfigToUpdate {
	config.set(idTokenResponseTypeKey, enabled)
	return config
}

// CodeResponseType sets whether to enable the code response flow for the provider.
// By default, this is not enabled if no response type is specified.
// A client secret must be set for this response type.
// Having both the code and ID token response flows is currently not supported.
func (config *OIDCProviderConfigToUpdate) CodeResponseType(enabled bool) *OIDCProviderConfigToUpdate { return config.set(codeResponseTypeKey, enabled) } func (config *OIDCProviderConfigToUpdate) set(key string, value interface{}) *OIDCProviderConfigToUpdate { if config.params == nil { config.params = make(nestedMap) } config.params.Set(key, value) return config } func (config *OIDCProviderConfigToUpdate) buildRequest() (nestedMap, error) { if len(config.params) == 0 { return nil, errors.New("no parameters specified in the update request") } if val, ok := config.params.GetString(clientIDKey); ok && val == "" { return nil, errors.New("ClientID must not be empty") } if val, ok := config.params.GetString(issuerKey); ok { if val == "" { return nil, errors.New("Issuer must not be empty") } if _, err := url.ParseRequestURI(val); err != nil { return nil, fmt.Errorf("failed to parse Issuer: %v", err) } } if val, ok := config.params.Get(codeResponseTypeKey); ok && val.(bool) { if val, ok := config.params.GetString(clientSecretKey); !ok || val == "" { return nil, errors.New("Client Secret must not be empty for Code Response Type") } if val, ok := config.params.Get(idTokenResponseTypeKey); ok && val.(bool) { return nil, errors.New("Only one response type may be chosen") } } else if ok && !val.(bool) { if val, ok := config.params.Get(idTokenResponseTypeKey); ok && !val.(bool) { return nil, errors.New("At least one response type must be returned") } } return config.params, nil } // OIDCProviderConfigIterator is an iterator over OIDC provider configurations. type OIDCProviderConfigIterator struct { client *baseClient ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo configs []*OIDCProviderConfig } // PageInfo supports pagination. func (it *OIDCProviderConfigIterator) PageInfo() *iterator.PageInfo { return it.pageInfo } // Next returns the next OIDCProviderConfig. The error value of [iterator.Done] is // returned if there are no more results. 
Once Next returns [iterator.Done], all // subsequent calls will return [iterator.Done]. func (it *OIDCProviderConfigIterator) Next() (*OIDCProviderConfig, error) { if err := it.nextFunc(); err != nil { return nil, err } config := it.configs[0] it.configs = it.configs[1:] return config, nil } func (it *OIDCProviderConfigIterator) fetch(pageSize int, pageToken string) (string, error) { params := map[string]string{ "pageSize": strconv.Itoa(pageSize), } if pageToken != "" { params["pageToken"] = pageToken } req := &internal.Request{ Method: http.MethodGet, URL: "/oauthIdpConfigs", Opts: []internal.HTTPOption{ internal.WithQueryParams(params), }, } var result struct { Configs []oidcProviderConfigDAO `json:"oauthIdpConfigs"` NextPageToken string `json:"nextPageToken"` } if _, err := it.client.makeRequest(it.ctx, req, &result); err != nil { return "", err } for _, config := range result.Configs { it.configs = append(it.configs, config.toOIDCProviderConfig()) } it.pageInfo.Token = result.NextPageToken return result.NextPageToken, nil } // SAMLProviderConfig is the SAML auth provider configuration. // See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html. type SAMLProviderConfig struct { ID string DisplayName string Enabled bool IDPEntityID string SSOURL string RequestSigningEnabled bool X509Certificates []string RPEntityID string CallbackURL string } // SAMLProviderConfigToCreate represents the options used to create a new SAMLProviderConfig. type SAMLProviderConfigToCreate struct { id string params nestedMap } // ID sets the provider ID of the new config. func (config *SAMLProviderConfigToCreate) ID(id string) *SAMLProviderConfigToCreate { config.id = id return config } // IDPEntityID sets the IDPEntityID field of the new config. func (config *SAMLProviderConfigToCreate) IDPEntityID(entityID string) *SAMLProviderConfigToCreate { return config.set(idpEntityIDKey, entityID) } // SSOURL sets the SSOURL field of the new config. 
func (config *SAMLProviderConfigToCreate) SSOURL(url string) *SAMLProviderConfigToCreate {
	config.set(ssoURLKey, url)
	return config
}

// RequestSigningEnabled turns request signing support on or off.
func (config *SAMLProviderConfigToCreate) RequestSigningEnabled(enabled bool) *SAMLProviderConfigToCreate {
	config.set(signRequestKey, enabled)
	return config
}

// X509Certificates sets the certificates for the new config.
// Each string is wrapped in an idpCertificate before being stored.
func (config *SAMLProviderConfigToCreate) X509Certificates(certs []string) *SAMLProviderConfigToCreate {
	// Left nil (not an empty slice) when certs is empty.
	var wrapped []idpCertificate
	for i := range certs {
		wrapped = append(wrapped, idpCertificate{certs[i]})
	}
	return config.set(idpCertsKey, wrapped)
}

// RPEntityID sets the RPEntityID field of the new config.
func (config *SAMLProviderConfigToCreate) RPEntityID(entityID string) *SAMLProviderConfigToCreate {
	config.set(spEntityIDKey, entityID)
	return config
}

// CallbackURL sets the CallbackURL field of the new config.
func (config *SAMLProviderConfigToCreate) CallbackURL(url string) *SAMLProviderConfigToCreate {
	config.set(callbackURIKey, url)
	return config
}

// DisplayName sets the DisplayName field of the new config.
func (config *SAMLProviderConfigToCreate) DisplayName(name string) *SAMLProviderConfigToCreate {
	config.set(displayNameKey, name)
	return config
}

// Enabled enables or disables the new config.
func (config *SAMLProviderConfigToCreate) Enabled(enabled bool) *SAMLProviderConfigToCreate { return config.set(enabledKey, enabled) } func (config *SAMLProviderConfigToCreate) set(key string, value interface{}) *SAMLProviderConfigToCreate { if config.params == nil { config.params = make(nestedMap) } config.params.Set(key, value) return config } func (config *SAMLProviderConfigToCreate) buildRequest() (nestedMap, string, error) { if err := validateSAMLConfigID(config.id); err != nil { return nil, "", err } if len(config.params) == 0 { return nil, "", errors.New("no parameters specified in the create request") } if val, ok := config.params.GetString(idpEntityIDKey); !ok || val == "" { return nil, "", errors.New("IDPEntityID must not be empty") } if val, ok := config.params.GetString(ssoURLKey); !ok || val == "" { return nil, "", errors.New("SSOURL must not be empty") } else if _, err := url.ParseRequestURI(val); err != nil { return nil, "", fmt.Errorf("failed to parse SSOURL: %v", err) } var certs interface{} var ok bool if certs, ok = config.params.Get(idpCertsKey); !ok || len(certs.([]idpCertificate)) == 0 { return nil, "", errors.New("X509Certificates must not be empty") } for _, cert := range certs.([]idpCertificate) { if cert.X509Certificate == "" { return nil, "", errors.New("X509Certificates must not contain empty strings") } } if val, ok := config.params.GetString(spEntityIDKey); !ok || val == "" { return nil, "", errors.New("RPEntityID must not be empty") } if val, ok := config.params.GetString(callbackURIKey); !ok || val == "" { return nil, "", errors.New("CallbackURL must not be empty") } else if _, err := url.ParseRequestURI(val); err != nil { return nil, "", fmt.Errorf("failed to parse CallbackURL: %v", err) } return config.params, config.id, nil } // SAMLProviderConfigToUpdate represents the options used to update an existing SAMLProviderConfig. 
type SAMLProviderConfigToUpdate struct {
	params nestedMap
}

// IDPEntityID updates the IDPEntityID field of the config.
func (config *SAMLProviderConfigToUpdate) IDPEntityID(entityID string) *SAMLProviderConfigToUpdate {
	config.set(idpEntityIDKey, entityID)
	return config
}

// SSOURL updates the SSOURL field of the config.
func (config *SAMLProviderConfigToUpdate) SSOURL(url string) *SAMLProviderConfigToUpdate {
	config.set(ssoURLKey, url)
	return config
}

// RequestSigningEnabled turns request signing support on or off.
func (config *SAMLProviderConfigToUpdate) RequestSigningEnabled(enabled bool) *SAMLProviderConfigToUpdate {
	config.set(signRequestKey, enabled)
	return config
}

// X509Certificates updates the certificates of the config.
// Each string is wrapped in an idpCertificate before being stored.
func (config *SAMLProviderConfigToUpdate) X509Certificates(certs []string) *SAMLProviderConfigToUpdate {
	// Left nil (not an empty slice) when certs is empty.
	var wrapped []idpCertificate
	for i := range certs {
		wrapped = append(wrapped, idpCertificate{certs[i]})
	}
	return config.set(idpCertsKey, wrapped)
}

// RPEntityID updates the RPEntityID field of the config.
func (config *SAMLProviderConfigToUpdate) RPEntityID(entityID string) *SAMLProviderConfigToUpdate {
	config.set(spEntityIDKey, entityID)
	return config
}

// CallbackURL updates the CallbackURL field of the config.
func (config *SAMLProviderConfigToUpdate) CallbackURL(url string) *SAMLProviderConfigToUpdate {
	config.set(callbackURIKey, url)
	return config
}

// DisplayName updates the DisplayName field of the config.
// An empty name is stored as a nil value rather than as an empty string.
func (config *SAMLProviderConfigToUpdate) DisplayName(name string) *SAMLProviderConfigToUpdate {
	if name == "" {
		return config.set(displayNameKey, nil)
	}
	return config.set(displayNameKey, name)
}

// Enabled enables or disables the config.
func (config *SAMLProviderConfigToUpdate) Enabled(enabled bool) *SAMLProviderConfigToUpdate { return config.set(enabledKey, enabled) } func (config *SAMLProviderConfigToUpdate) set(key string, value interface{}) *SAMLProviderConfigToUpdate { if config.params == nil { config.params = make(nestedMap) } config.params.Set(key, value) return config } func (config *SAMLProviderConfigToUpdate) buildRequest() (nestedMap, error) { if len(config.params) == 0 { return nil, errors.New("no parameters specified in the update request") } if val, ok := config.params.GetString(idpEntityIDKey); ok && val == "" { return nil, errors.New("IDPEntityID must not be empty") } if val, ok := config.params.GetString(ssoURLKey); ok { if val == "" { return nil, errors.New("SSOURL must not be empty") } if _, err := url.ParseRequestURI(val); err != nil { return nil, fmt.Errorf("failed to parse SSOURL: %v", err) } } if val, ok := config.params.Get(idpCertsKey); ok { if len(val.([]idpCertificate)) == 0 { return nil, errors.New("X509Certificates must not be empty") } for _, cert := range val.([]idpCertificate) { if cert.X509Certificate == "" { return nil, errors.New("X509Certificates must not contain empty strings") } } } if val, ok := config.params.GetString(spEntityIDKey); ok && val == "" { return nil, errors.New("RPEntityID must not be empty") } if val, ok := config.params.GetString(callbackURIKey); ok { if val == "" { return nil, errors.New("CallbackURL must not be empty") } if _, err := url.ParseRequestURI(val); err != nil { return nil, fmt.Errorf("failed to parse CallbackURL: %v", err) } } return config.params, nil } // SAMLProviderConfigIterator is an iterator over SAML provider configurations. type SAMLProviderConfigIterator struct { client *baseClient ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo configs []*SAMLProviderConfig } // PageInfo supports pagination. 
func (it *SAMLProviderConfigIterator) PageInfo() *iterator.PageInfo {
	info := it.pageInfo
	return info
}

// Next returns the next SAMLProviderConfig. The error value of [iterator.Done] is
// returned if there are no more results. Once Next returns [iterator.Done], all
// subsequent calls will return [iterator.Done].
func (it *SAMLProviderConfigIterator) Next() (*SAMLProviderConfig, error) {
	if err := it.nextFunc(); err != nil {
		return nil, err
	}

	// Pop the head of the buffered configs.
	head, rest := it.configs[0], it.configs[1:]
	it.configs = rest
	return head, nil
}

// fetch requests one page of SAML provider configs from "/inboundSamlConfigs",
// appends the converted entries to the iterator's buffer, records the next
// page token on pageInfo and returns that token.
func (it *SAMLProviderConfigIterator) fetch(pageSize int, pageToken string) (string, error) {
	query := map[string]string{
		"pageSize": strconv.Itoa(pageSize),
	}
	if pageToken != "" {
		query["pageToken"] = pageToken
	}

	var page struct {
		Configs       []samlProviderConfigDAO `json:"inboundSamlConfigs"`
		NextPageToken string                  `json:"nextPageToken"`
	}
	req := &internal.Request{
		Method: http.MethodGet,
		URL:    "/inboundSamlConfigs",
		Opts:   []internal.HTTPOption{internal.WithQueryParams(query)},
	}
	if _, err := it.client.makeRequest(it.ctx, req, &page); err != nil {
		return "", err
	}

	for _, dao := range page.Configs {
		it.configs = append(it.configs, dao.toSAMLProviderConfig())
	}
	token := page.NextPageToken
	it.pageInfo.Token = token
	return token, nil
}

// OIDCProviderConfig returns the OIDCProviderConfig with the given ID.
func (c *baseClient) OIDCProviderConfig(ctx context.Context, id string) (*OIDCProviderConfig, error) {
	if err := validateOIDCConfigID(id); err != nil {
		return nil, err
	}

	var dao oidcProviderConfigDAO
	req := &internal.Request{
		Method: http.MethodGet,
		URL:    fmt.Sprintf("/oauthIdpConfigs/%s", id),
	}
	if _, err := c.makeRequest(ctx, req, &dao); err != nil {
		return nil, err
	}
	return dao.toOIDCProviderConfig(), nil
}

// CreateOIDCProviderConfig creates a new OIDC provider config from the given parameters.
func (c *baseClient) CreateOIDCProviderConfig(ctx context.Context, config *OIDCProviderConfigToCreate) (*OIDCProviderConfig, error) { if config == nil { return nil, errors.New("config must not be nil") } body, id, err := config.buildRequest() if err != nil { return nil, err } req := &internal.Request{ Method: http.MethodPost, URL: "/oauthIdpConfigs", Body: internal.NewJSONEntity(body), Opts: []internal.HTTPOption{ internal.WithQueryParam("oauthIdpConfigId", id), }, } var result oidcProviderConfigDAO if _, err := c.makeRequest(ctx, req, &result); err != nil { return nil, err } return result.toOIDCProviderConfig(), nil } // UpdateOIDCProviderConfig updates an existing OIDC provider config with the given parameters. func (c *baseClient) UpdateOIDCProviderConfig(ctx context.Context, id string, config *OIDCProviderConfigToUpdate) (*OIDCProviderConfig, error) { if err := validateOIDCConfigID(id); err != nil { return nil, err } if config == nil { return nil, errors.New("config must not be nil") } body, err := config.buildRequest() if err != nil { return nil, err } mask := body.UpdateMask() req := &internal.Request{ Method: http.MethodPatch, URL: fmt.Sprintf("/oauthIdpConfigs/%s", id), Body: internal.NewJSONEntity(body), Opts: []internal.HTTPOption{ internal.WithQueryParam("updateMask", strings.Join(mask, ",")), }, } var result oidcProviderConfigDAO if _, err := c.makeRequest(ctx, req, &result); err != nil { return nil, err } return result.toOIDCProviderConfig(), nil } // DeleteOIDCProviderConfig deletes the OIDCProviderConfig with the given ID. func (c *baseClient) DeleteOIDCProviderConfig(ctx context.Context, id string) error { if err := validateOIDCConfigID(id); err != nil { return err } req := &internal.Request{ Method: http.MethodDelete, URL: fmt.Sprintf("/oauthIdpConfigs/%s", id), } _, err := c.makeRequest(ctx, req, nil) return err } // OIDCProviderConfigs returns an iterator over OIDC provider configurations. 
// // If nextPageToken is empty, the iterator will start at the beginning. Otherwise, // iterator starts after the token. func (c *baseClient) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *OIDCProviderConfigIterator { it := &OIDCProviderConfigIterator{ ctx: ctx, client: c, } it.pageInfo, it.nextFunc = iterator.NewPageInfo( it.fetch, func() int { return len(it.configs) }, func() interface{} { b := it.configs; it.configs = nil; return b }) it.pageInfo.MaxSize = maxConfigs it.pageInfo.Token = nextPageToken return it } // SAMLProviderConfig returns the SAMLProviderConfig with the given ID. func (c *baseClient) SAMLProviderConfig(ctx context.Context, id string) (*SAMLProviderConfig, error) { if err := validateSAMLConfigID(id); err != nil { return nil, err } req := &internal.Request{ Method: http.MethodGet, URL: fmt.Sprintf("/inboundSamlConfigs/%s", id), } var result samlProviderConfigDAO if _, err := c.makeRequest(ctx, req, &result); err != nil { return nil, err } return result.toSAMLProviderConfig(), nil } // CreateSAMLProviderConfig creates a new SAML provider config from the given parameters. func (c *baseClient) CreateSAMLProviderConfig(ctx context.Context, config *SAMLProviderConfigToCreate) (*SAMLProviderConfig, error) { if config == nil { return nil, errors.New("config must not be nil") } body, id, err := config.buildRequest() if err != nil { return nil, err } req := &internal.Request{ Method: http.MethodPost, URL: "/inboundSamlConfigs", Body: internal.NewJSONEntity(body), Opts: []internal.HTTPOption{ internal.WithQueryParam("inboundSamlConfigId", id), }, } var result samlProviderConfigDAO if _, err := c.makeRequest(ctx, req, &result); err != nil { return nil, err } return result.toSAMLProviderConfig(), nil } // UpdateSAMLProviderConfig updates an existing SAML provider config with the given parameters. 
func (c *baseClient) UpdateSAMLProviderConfig(ctx context.Context, id string, config *SAMLProviderConfigToUpdate) (*SAMLProviderConfig, error) { if err := validateSAMLConfigID(id); err != nil { return nil, err } if config == nil { return nil, errors.New("config must not be nil") } body, err := config.buildRequest() if err != nil { return nil, err } mask := body.UpdateMask() req := &internal.Request{ Method: http.MethodPatch, URL: fmt.Sprintf("/inboundSamlConfigs/%s", id), Body: internal.NewJSONEntity(body), Opts: []internal.HTTPOption{ internal.WithQueryParam("updateMask", strings.Join(mask, ",")), }, } var result samlProviderConfigDAO if _, err := c.makeRequest(ctx, req, &result); err != nil { return nil, err } return result.toSAMLProviderConfig(), nil } // DeleteSAMLProviderConfig deletes the SAMLProviderConfig with the given ID. func (c *baseClient) DeleteSAMLProviderConfig(ctx context.Context, id string) error { if err := validateSAMLConfigID(id); err != nil { return err } req := &internal.Request{ Method: http.MethodDelete, URL: fmt.Sprintf("/inboundSamlConfigs/%s", id), } _, err := c.makeRequest(ctx, req, nil) return err } // SAMLProviderConfigs returns an iterator over SAML provider configurations. // // If nextPageToken is empty, the iterator will start at the beginning. Otherwise, // iterator starts after the token. 
func (c *baseClient) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *SAMLProviderConfigIterator { it := &SAMLProviderConfigIterator{ ctx: ctx, client: c, } it.pageInfo, it.nextFunc = iterator.NewPageInfo( it.fetch, func() int { return len(it.configs) }, func() interface{} { b := it.configs; it.configs = nil; return b }) it.pageInfo.MaxSize = maxConfigs it.pageInfo.Token = nextPageToken return it } func (c *baseClient) makeRequest( ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) { if c.projectID == "" { return nil, errors.New("project id not available") } if c.tenantID != "" { req.URL = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.providerConfigEndpoint, c.projectID, c.tenantID, req.URL) } else { req.URL = fmt.Sprintf("%s/projects/%s%s", c.providerConfigEndpoint, c.projectID, req.URL) } return c.httpClient.DoAndUnmarshal(ctx, req, v) } type oidcProviderConfigDAO struct { Name string `json:"name"` ClientID string `json:"clientId"` Issuer string `json:"issuer"` DisplayName string `json:"displayName"` Enabled bool `json:"enabled"` ClientSecret string `json:"clientSecret"` ResponseType oidcProviderResponseType `json:"responseType"` } type oidcProviderResponseType struct { Code bool `json:"code"` IDToken bool `json:"idToken"` } func (dao *oidcProviderConfigDAO) toOIDCProviderConfig() *OIDCProviderConfig { return &OIDCProviderConfig{ ID: extractResourceID(dao.Name), DisplayName: dao.DisplayName, Enabled: dao.Enabled, ClientID: dao.ClientID, Issuer: dao.Issuer, ClientSecret: dao.ClientSecret, CodeResponseType: dao.ResponseType.Code, IDTokenResponseType: dao.ResponseType.IDToken, } } type idpCertificate struct { X509Certificate string `json:"x509Certificate"` } type samlProviderConfigDAO struct { Name string `json:"name"` IDPConfig struct { IDPEntityID string `json:"idpEntityId"` SSOURL string `json:"ssoUrl"` IDPCertificates []idpCertificate `json:"idpCertificates"` SignRequest bool `json:"signRequest"` } 
`json:"idpConfig"` SPConfig struct { SPEntityID string `json:"spEntityId"` CallbackURI string `json:"callbackUri"` } `json:"spConfig"` DisplayName string `json:"displayName"` Enabled bool `json:"enabled"` } func (dao *samlProviderConfigDAO) toSAMLProviderConfig() *SAMLProviderConfig { var certs []string for _, cert := range dao.IDPConfig.IDPCertificates { certs = append(certs, cert.X509Certificate) } return &SAMLProviderConfig{ ID: extractResourceID(dao.Name), DisplayName: dao.DisplayName, Enabled: dao.Enabled, IDPEntityID: dao.IDPConfig.IDPEntityID, SSOURL: dao.IDPConfig.SSOURL, RequestSigningEnabled: dao.IDPConfig.SignRequest, X509Certificates: certs, RPEntityID: dao.SPConfig.SPEntityID, CallbackURL: dao.SPConfig.CallbackURI, } } func validateOIDCConfigID(id string) error { if !strings.HasPrefix(id, "oidc.") { return fmt.Errorf("invalid OIDC provider id: %q", id) } return nil } func validateSAMLConfigID(id string) error { if !strings.HasPrefix(id, "saml.") { return fmt.Errorf("invalid SAML provider id: %q", id) } return nil } func extractResourceID(name string) string { // name format: "projects/project-id/resource/resource-id" segments := strings.Split(name, "/") return segments[len(segments)-1] } golang-google-firebase-go-4.18.0/auth/provider_config_test.go000066400000000000000000001222131505612111400242110ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "encoding/json" "fmt" "net/http" "reflect" "sort" "strings" "testing" "firebase.google.com/go/v4/errorutils" "google.golang.org/api/iterator" ) const oidcConfigResponse = `{ "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider", "clientId": "CLIENT_ID", "issuer": "https://oidc.com/issuer", "displayName": "oidcProviderName", "enabled": true, "clientSecret": "CLIENT_SECRET", "responseType": { "code": true, "idToken": true } }` const samlConfigResponse = `{ "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider", "idpConfig": { "idpEntityId": "IDP_ENTITY_ID", "ssoUrl": "https://example.com/login", "signRequest": true, "idpCertificates": [ {"x509Certificate": "CERT1"}, {"x509Certificate": "CERT2"} ] }, "spConfig": { "spEntityId": "RP_ENTITY_ID", "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" }, "displayName": "samlProviderName", "enabled": true }` const notFoundResponse = `{ "error": { "message": "CONFIGURATION_NOT_FOUND" } }` var idpCertsMap = []interface{}{ map[string]interface{}{"x509Certificate": "CERT1"}, map[string]interface{}{"x509Certificate": "CERT2"}, } var oidcProviderConfig = &OIDCProviderConfig{ ID: "oidc.provider", DisplayName: "oidcProviderName", Enabled: true, ClientID: "CLIENT_ID", Issuer: "https://oidc.com/issuer", ClientSecret: "CLIENT_SECRET", CodeResponseType: true, IDTokenResponseType: true, } var samlProviderConfig = &SAMLProviderConfig{ ID: "saml.provider", DisplayName: "samlProviderName", Enabled: true, IDPEntityID: "IDP_ENTITY_ID", SSOURL: "https://example.com/login", RequestSigningEnabled: true, X509Certificates: []string{"CERT1", "CERT2"}, RPEntityID: "RP_ENTITY_ID", CallbackURL: "https://projectId.firebaseapp.com/__/auth/handler", } var invalidOIDCConfigIDs = []string{ "", "invalid.id", "saml.config", } var invalidSAMLConfigIDs = []string{ "", "invalid.id", "oidc.config", } func TestOIDCProviderConfig(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) 
defer s.Close() client := s.Client oidc, err := client.OIDCProviderConfig(context.Background(), "oidc.provider") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("OIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } req := s.Req[0] if req.Method != http.MethodGet { t.Errorf("OIDCProviderConfig() Method = %q; want = %q", req.Method, http.MethodGet) } wantURL := "/projects/mock-project-id/oauthIdpConfigs/oidc.provider" if req.URL.Path != wantURL { t.Errorf("OIDCProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestOIDCProviderConfigInvalidID(t *testing.T) { client := &baseClient{} wantErr := "invalid OIDC provider id: " for _, id := range invalidOIDCConfigIDs { saml, err := client.OIDCProviderConfig(context.Background(), id) if saml != nil || err == nil || !strings.HasPrefix(err.Error(), wantErr) { t.Errorf("OIDCProviderConfig(%q) = (%v, %v); want = (nil, %q)", id, saml, err, wantErr) } } } func TestOIDCProviderConfigError(t *testing.T) { s := echoServer([]byte(notFoundResponse), t) defer s.Close() s.Status = http.StatusNotFound client := s.Client saml, err := client.OIDCProviderConfig(context.Background(), "oidc.provider") if saml != nil || err == nil || !IsConfigurationNotFound(err) { t.Errorf("OIDCProviderConfig() = (%v, %v); want = (nil, ConfigurationNotFound)", saml, err) } } func TestCreateOIDCProviderConfig(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client := s.Client options := (&OIDCProviderConfigToCreate{}). ID(oidcProviderConfig.ID). DisplayName(oidcProviderConfig.DisplayName). Enabled(oidcProviderConfig.Enabled). ClientID(oidcProviderConfig.ClientID). Issuer(oidcProviderConfig.Issuer). ClientSecret(oidcProviderConfig.ClientSecret). CodeResponseType(true). 
IDTokenResponseType(false) oidc, err := client.CreateOIDCProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("CreateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": oidcProviderConfig.DisplayName, "enabled": oidcProviderConfig.Enabled, "clientId": oidcProviderConfig.ClientID, "issuer": oidcProviderConfig.Issuer, "clientSecret": oidcProviderConfig.ClientSecret, "responseType": map[string]interface{}{ "code": true, "idToken": false, }, } if err := checkCreateOIDCConfigRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateOIDCProviderConfigMinimal(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client := s.Client options := (&OIDCProviderConfigToCreate{}). ID(oidcProviderConfig.ID). ClientID(oidcProviderConfig.ClientID). Issuer(oidcProviderConfig.Issuer). IDTokenResponseType(true) oidc, err := client.CreateOIDCProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("CreateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "clientId": oidcProviderConfig.ClientID, "issuer": oidcProviderConfig.Issuer, "responseType": map[string]interface{}{"idToken": true}, } if err := checkCreateOIDCConfigRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateOIDCProviderConfigZeroValues(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client := s.Client options := (&OIDCProviderConfigToCreate{}). ID(oidcProviderConfig.ID). DisplayName(""). Enabled(false). ClientID(oidcProviderConfig.ClientID). Issuer(oidcProviderConfig.Issuer). CodeResponseType(false). 
IDTokenResponseType(true) oidc, err := client.CreateOIDCProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("CreateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": "", "enabled": false, "clientId": oidcProviderConfig.ClientID, "issuer": oidcProviderConfig.Issuer, "responseType": map[string]interface{}{ "code": false, "idToken": true, }, } if err := checkCreateOIDCConfigRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateOIDCProviderConfigError(t *testing.T) { s := echoServer([]byte("{}"), t) s.Status = http.StatusInternalServerError defer s.Close() client := s.Client client.baseClient.httpClient.RetryConfig = nil options := (&OIDCProviderConfigToCreate{}). ID(oidcProviderConfig.ID). ClientID(oidcProviderConfig.ClientID). Issuer(oidcProviderConfig.Issuer). IDTokenResponseType(true) oidc, err := client.CreateOIDCProviderConfig(context.Background(), options) if oidc != nil || !errorutils.IsInternal(err) { t.Errorf("CreateOIDCProviderConfig() = (%v, %v); want = (nil, %q)", oidc, err, "internal-error") } } func TestCreateOIDCProviderConfigInvalidInput(t *testing.T) { cases := []struct { name string want string conf *OIDCProviderConfigToCreate }{ { name: "NilConfig", want: "config must not be nil", conf: nil, }, { name: "EmptyID", want: "invalid OIDC provider id: ", conf: &OIDCProviderConfigToCreate{}, }, { name: "InvalidID", want: "invalid OIDC provider id: ", conf: (&OIDCProviderConfigToCreate{}). ID("saml.provider"), }, { name: "EmptyOptions", want: "no parameters specified in the create request", conf: (&OIDCProviderConfigToCreate{}). ID("oidc.provider"), }, { name: "EmptyClientID", want: "ClientID must not be empty", conf: (&OIDCProviderConfigToCreate{}). ID("oidc.provider"). ClientID(""), }, { name: "EmptyIssuer", want: "Issuer must not be empty", conf: (&OIDCProviderConfigToCreate{}). 
ID("oidc.provider"). ClientID("CLIENT_ID"), }, { name: "InvalidIssuer", want: "failed to parse Issuer: ", conf: (&OIDCProviderConfigToCreate{}). ID("oidc.provider"). ClientID("CLIENT_ID"). Issuer("not a url"), }, { name: "MissingClientSecret", want: "Client Secret must not be empty for Code Response Type", conf: (&OIDCProviderConfigToCreate{}). ID("oidc.provider"). ClientID("CLIENT_ID"). Issuer("https://oidc.com/issuer"). CodeResponseType(true), }, { name: "TwoResponseTypes", want: "Only one response type may be chosen", conf: (&OIDCProviderConfigToCreate{}). ID("oidc.provider"). ClientID("CLIENT_ID"). Issuer("https://oidc.com/issuer"). IDTokenResponseType(true). CodeResponseType(true). ClientSecret("secret"), }, { name: "ZeroResponseTypes", want: "At least one response type must be returned", conf: (&OIDCProviderConfigToCreate{}). ID("oidc.provider"). ClientID("CLIENT_ID"). Issuer("https://oidc.com/issuer"). IDTokenResponseType(false). CodeResponseType(false), }, } client := &baseClient{} for _, tc := range cases { _, err := client.CreateOIDCProviderConfig(context.Background(), tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { t.Errorf("CreateOIDCProviderConfig(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestUpdateOIDCProviderConfig(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client := s.Client options := (&OIDCProviderConfigToUpdate{}). DisplayName(oidcProviderConfig.DisplayName). Enabled(oidcProviderConfig.Enabled). ClientID(oidcProviderConfig.ClientID). Issuer(oidcProviderConfig.Issuer). ClientSecret(oidcProviderConfig.ClientSecret). CodeResponseType(true). 
IDTokenResponseType(false) oidc, err := client.UpdateOIDCProviderConfig(context.Background(), "oidc.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("UpdateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": oidcProviderConfig.DisplayName, "enabled": oidcProviderConfig.Enabled, "clientId": oidcProviderConfig.ClientID, "issuer": oidcProviderConfig.Issuer, "clientSecret": oidcProviderConfig.ClientSecret, "responseType": map[string]interface{}{ "code": true, "idToken": false, }, } wantMask := []string{ "clientId", "clientSecret", "displayName", "enabled", "issuer", "responseType.code", "responseType.idToken", } if err := checkUpdateOIDCConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateOIDCProviderConfigMinimal(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client := s.Client options := (&OIDCProviderConfigToUpdate{}). DisplayName("Other name") oidc, err := client.UpdateOIDCProviderConfig(context.Background(), "oidc.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("UpdateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": "Other name", } wantMask := []string{ "displayName", } if err := checkUpdateOIDCConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateOIDCProviderConfigZeroValues(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client := s.Client options := (&OIDCProviderConfigToUpdate{}). DisplayName(""). Enabled(false). CodeResponseType(false). 
IDTokenResponseType(true) oidc, err := client.UpdateOIDCProviderConfig(context.Background(), "oidc.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("UpdateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": nil, "enabled": false, "responseType": map[string]interface{}{ "code": false, "idToken": true, }, } wantMask := []string{ "displayName", "enabled", "responseType.code", "responseType.idToken", } if err := checkUpdateOIDCConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateOIDCProviderConfigInvalidID(t *testing.T) { cases := []string{"", "saml.config"} client := &baseClient{} options := (&OIDCProviderConfigToUpdate{}). DisplayName("") want := "invalid OIDC provider id: " for _, tc := range cases { _, err := client.UpdateOIDCProviderConfig(context.Background(), tc, options) if err == nil || !strings.HasPrefix(err.Error(), want) { t.Errorf("UpdateOIDCProviderConfig(%q) = %v; want = %q", tc, err, want) } } } func TestUpdateOIDCProviderConfigInvalidInput(t *testing.T) { cases := []struct { name string want string conf *OIDCProviderConfigToUpdate }{ { name: "NilConfig", want: "config must not be nil", conf: nil, }, { name: "Empty", want: "no parameters specified in the update request", conf: &OIDCProviderConfigToUpdate{}, }, { name: "EmptyClientID", want: "ClientID must not be empty", conf: (&OIDCProviderConfigToUpdate{}). ClientID(""), }, { name: "EmptyIssuer", want: "Issuer must not be empty", conf: (&OIDCProviderConfigToUpdate{}). Issuer(""), }, { name: "InvalidIssuer", want: "failed to parse Issuer: ", conf: (&OIDCProviderConfigToUpdate{}). Issuer("not a url"), }, { name: "MissingClientSecret", want: "Client Secret must not be empty for Code Response Type", conf: (&OIDCProviderConfigToUpdate{}). Issuer("https://oidc.com/issuer"). 
CodeResponseType(true), }, { name: "TwoResponseTypes", want: "Only one response type may be chosen", conf: (&OIDCProviderConfigToUpdate{}). Issuer("https://oidc.com/issuer"). IDTokenResponseType(true). CodeResponseType(true). ClientSecret("secret"), }, { name: "ZeroResponseTypes", want: "At least one response type must be returned", conf: (&OIDCProviderConfigToUpdate{}). Issuer("https://oidc.com/issuer"). IDTokenResponseType(false). CodeResponseType(false), }, } client := &baseClient{} for _, tc := range cases { _, err := client.UpdateOIDCProviderConfig(context.Background(), "oidc.provider", tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { t.Errorf("UpdateOIDCProviderConfig(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestDeleteOIDCProviderConfig(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client := s.Client if err := client.DeleteOIDCProviderConfig(context.Background(), "oidc.provider"); err != nil { t.Fatal(err) } req := s.Req[0] if req.Method != http.MethodDelete { t.Errorf("DeleteOIDCProviderConfig() Method = %q; want = %q", req.Method, http.MethodDelete) } wantURL := "/projects/mock-project-id/oauthIdpConfigs/oidc.provider" if req.URL.Path != wantURL { t.Errorf("DeleteOIDCProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestDeleteOIDCProviderConfigInvalidID(t *testing.T) { client := &baseClient{} wantErr := "invalid OIDC provider id: " for _, id := range invalidOIDCConfigIDs { err := client.DeleteOIDCProviderConfig(context.Background(), id) if err == nil || !strings.HasPrefix(err.Error(), wantErr) { t.Errorf("DeleteOIDCProviderConfig(%q) = %v; want = %q", id, err, wantErr) } } } func TestDeleteOIDCProviderConfigError(t *testing.T) { s := echoServer([]byte(notFoundResponse), t) defer s.Close() s.Status = http.StatusNotFound client := s.Client err := client.DeleteOIDCProviderConfig(context.Background(), "oidc.provider") if err == nil || !IsConfigurationNotFound(err) { 
t.Errorf("DeleteOIDCProviderConfig() = %v; want = ConfigurationNotFound", err) } } func TestOIDCProviderConfigs(t *testing.T) { template := `{ "oauthIdpConfigs": [ %s, %s, %s ], "nextPageToken": "" }` response := fmt.Sprintf(template, oidcConfigResponse, oidcConfigResponse, oidcConfigResponse) s := echoServer([]byte(response), t) defer s.Close() want := []*OIDCProviderConfig{ oidcProviderConfig, oidcProviderConfig, oidcProviderConfig, } wantPath := "/projects/mock-project-id/oauthIdpConfigs" testIterator := func(iter *OIDCProviderConfigIterator, token string, req string) { count := 0 for i := 0; i < len(want); i++ { config, err := iter.Next() if err == iterator.Done { break } if err != nil { t.Fatal(err) } if !reflect.DeepEqual(config, want[i]) { t.Errorf("OIDCProviderConfigs(%q) = %#v; want = %#v", token, config, want[i]) } count++ } if count != len(want) { t.Errorf("OIDCProviderConfigs(%q) = %d; want = %d", token, count, len(want)) } if _, err := iter.Next(); err != iterator.Done { t.Errorf("OIDCProviderConfigs(%q) = %v; want = %v", token, err, iterator.Done) } url := s.Req[len(s.Req)-1].URL if url.Path != wantPath { t.Errorf("OIDCProviderConfigs(%q) = %q; want = %q", token, url.Path, wantPath) } // Check the query string of the last HTTP request made. 
gotReq := url.Query().Encode() if gotReq != req { t.Errorf("OIDCProviderConfigs(%q) = %q; want = %v", token, gotReq, req) } } client := s.Client testIterator( client.OIDCProviderConfigs(context.Background(), ""), "", "pageSize=100") testIterator( client.OIDCProviderConfigs(context.Background(), "pageToken"), "pageToken", "pageSize=100&pageToken=pageToken") } func TestOIDCProviderConfigsError(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() s.Status = http.StatusInternalServerError client := s.Client client.baseClient.httpClient.RetryConfig = nil it := client.OIDCProviderConfigs(context.Background(), "") config, err := it.Next() if config != nil || err == nil || !errorutils.IsInternal(err) { t.Errorf("OIDCProviderConfigs() = (%v, %v); want = (nil, %q)", config, err, "internal-error") } } func TestSAMLProviderConfig(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := s.Client saml, err := client.SAMLProviderConfig(context.Background(), "saml.provider") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("SAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } req := s.Req[0] if req.Method != http.MethodGet { t.Errorf("SAMLProviderConfig() Method = %q; want = %q", req.Method, http.MethodGet) } wantURL := "/projects/mock-project-id/inboundSamlConfigs/saml.provider" if req.URL.Path != wantURL { t.Errorf("SAMLProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestSAMLProviderConfigInvalidID(t *testing.T) { client := &baseClient{} wantErr := "invalid SAML provider id: " for _, id := range invalidSAMLConfigIDs { saml, err := client.SAMLProviderConfig(context.Background(), id) if saml != nil || err == nil || !strings.HasPrefix(err.Error(), wantErr) { t.Errorf("SAMLProviderConfig(%q) = (%v, %v); want = (nil, %q)", id, saml, err, wantErr) } } } func TestSAMLProviderConfigError(t *testing.T) { s := echoServer([]byte(notFoundResponse), t) defer s.Close() 
s.Status = http.StatusNotFound client := s.Client saml, err := client.SAMLProviderConfig(context.Background(), "saml.provider") if saml != nil || err == nil || !IsConfigurationNotFound(err) { t.Errorf("SAMLProviderConfig() = (%v, %v); want = (nil, ConfigurationNotFound)", saml, err) } } func TestCreateSAMLProviderConfig(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := s.Client options := (&SAMLProviderConfigToCreate{}). ID(samlProviderConfig.ID). DisplayName(samlProviderConfig.DisplayName). Enabled(samlProviderConfig.Enabled). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). RequestSigningEnabled(samlProviderConfig.RequestSigningEnabled). X509Certificates(samlProviderConfig.X509Certificates). RPEntityID(samlProviderConfig.RPEntityID). CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.CreateSAMLProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("CreateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": samlProviderConfig.DisplayName, "enabled": samlProviderConfig.Enabled, "idpConfig": map[string]interface{}{ "idpEntityId": samlProviderConfig.IDPEntityID, "ssoUrl": samlProviderConfig.SSOURL, "signRequest": samlProviderConfig.RequestSigningEnabled, "idpCertificates": idpCertsMap, }, "spConfig": map[string]interface{}{ "spEntityId": samlProviderConfig.RPEntityID, "callbackUri": samlProviderConfig.CallbackURL, }, } if err := checkCreateSAMLConfigRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateSAMLProviderConfigMinimal(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := s.Client options := (&SAMLProviderConfigToCreate{}). ID(samlProviderConfig.ID). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). X509Certificates(samlProviderConfig.X509Certificates). 
RPEntityID(samlProviderConfig.RPEntityID). CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.CreateSAMLProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("CreateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "idpConfig": map[string]interface{}{ "idpEntityId": samlProviderConfig.IDPEntityID, "ssoUrl": samlProviderConfig.SSOURL, "idpCertificates": idpCertsMap, }, "spConfig": map[string]interface{}{ "spEntityId": samlProviderConfig.RPEntityID, "callbackUri": samlProviderConfig.CallbackURL, }, } if err := checkCreateSAMLConfigRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateSAMLProviderConfigZeroValues(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := s.Client options := (&SAMLProviderConfigToCreate{}). ID(samlProviderConfig.ID). DisplayName(""). Enabled(false). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). RequestSigningEnabled(false). X509Certificates(samlProviderConfig.X509Certificates). RPEntityID(samlProviderConfig.RPEntityID). 
CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.CreateSAMLProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("CreateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": "", "enabled": false, "idpConfig": map[string]interface{}{ "idpEntityId": samlProviderConfig.IDPEntityID, "ssoUrl": samlProviderConfig.SSOURL, "signRequest": false, "idpCertificates": idpCertsMap, }, "spConfig": map[string]interface{}{ "spEntityId": samlProviderConfig.RPEntityID, "callbackUri": samlProviderConfig.CallbackURL, }, } if err := checkCreateSAMLConfigRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateSAMLProviderConfigError(t *testing.T) { s := echoServer([]byte("{}"), t) s.Status = http.StatusInternalServerError defer s.Close() client := s.Client client.baseClient.httpClient.RetryConfig = nil options := (&SAMLProviderConfigToCreate{}). ID(samlProviderConfig.ID). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). X509Certificates(samlProviderConfig.X509Certificates). RPEntityID(samlProviderConfig.RPEntityID). CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.CreateSAMLProviderConfig(context.Background(), options) if saml != nil || !errorutils.IsInternal(err) { t.Errorf("CreateSAMLProviderConfig() = (%v, %v); want = (nil, %q)", saml, err, "internal-error") } } func TestCreateSAMLProviderConfigInvalidInput(t *testing.T) { cases := []struct { name string want string conf *SAMLProviderConfigToCreate }{ { name: "NilConfig", want: "config must not be nil", conf: nil, }, { name: "EmptyID", want: "invalid SAML provider id: ", conf: &SAMLProviderConfigToCreate{}, }, { name: "InvalidID", want: "invalid SAML provider id: ", conf: (&SAMLProviderConfigToCreate{}). 
ID("oidc.provider"), }, { name: "EmptyOptions", want: "no parameters specified in the create request", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"), }, { name: "EmptyEntityID", want: "IDPEntityID must not be empty", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID(""), }, { name: "EmptySSOURL", want: "SSOURL must not be empty", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"), }, { name: "InvalidSSOURL", want: "failed to parse SSOURL: ", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("not a url"), }, { name: "EmptyX509Certs", want: "X509Certificates must not be empty", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/login"), }, { name: "EmptyStringInX509Certs", want: "X509Certificates must not contain empty strings", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/login"). X509Certificates([]string{""}), }, { name: "EmptyRPEntityID", want: "RPEntityID must not be empty", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/login"). X509Certificates([]string{"CERT"}), }, { name: "EmptyCallbackURL", want: "CallbackURL must not be empty", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/login"). X509Certificates([]string{"CERT"}). RPEntityID("RP_ENTITY_ID"), }, { name: "InvalidCallbackURL", want: "failed to parse CallbackURL: ", conf: (&SAMLProviderConfigToCreate{}). ID("saml.provider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/login"). X509Certificates([]string{"CERT"}). RPEntityID("RP_ENTITY_ID"). 
CallbackURL("not a url"), }, } client := &baseClient{} for _, tc := range cases { _, err := client.CreateSAMLProviderConfig(context.Background(), tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { t.Errorf("CreateSAMLProviderConfig(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestUpdateSAMLProviderConfig(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := s.Client options := (&SAMLProviderConfigToUpdate{}). DisplayName(samlProviderConfig.DisplayName). Enabled(samlProviderConfig.Enabled). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). RequestSigningEnabled(samlProviderConfig.RequestSigningEnabled). X509Certificates(samlProviderConfig.X509Certificates). RPEntityID(samlProviderConfig.RPEntityID). CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.UpdateSAMLProviderConfig(context.Background(), "saml.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("UpdateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": samlProviderConfig.DisplayName, "enabled": samlProviderConfig.Enabled, "idpConfig": map[string]interface{}{ "idpEntityId": samlProviderConfig.IDPEntityID, "ssoUrl": samlProviderConfig.SSOURL, "signRequest": samlProviderConfig.RequestSigningEnabled, "idpCertificates": idpCertsMap, }, "spConfig": map[string]interface{}{ "spEntityId": samlProviderConfig.RPEntityID, "callbackUri": samlProviderConfig.CallbackURL, }, } wantMask := []string{ "displayName", "enabled", "idpConfig.idpCertificates", "idpConfig.idpEntityId", "idpConfig.signRequest", "idpConfig.ssoUrl", "spConfig.callbackUri", "spConfig.spEntityId", } if err := checkUpdateSAMLConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateSAMLProviderConfigMinimal(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := 
s.Client options := (&SAMLProviderConfigToUpdate{}). DisplayName("Other name") saml, err := client.UpdateSAMLProviderConfig(context.Background(), "saml.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("UpdateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": "Other name", } wantMask := []string{ "displayName", } if err := checkUpdateSAMLConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateSAMLProviderConfigZeroValues(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client := s.Client options := (&SAMLProviderConfigToUpdate{}). DisplayName(""). Enabled(false). RequestSigningEnabled(false) saml, err := client.UpdateSAMLProviderConfig(context.Background(), "saml.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("UpdateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": nil, "enabled": false, "idpConfig": map[string]interface{}{ "signRequest": false, }, } wantMask := []string{ "displayName", "enabled", "idpConfig.signRequest", } if err := checkUpdateSAMLConfigRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateSAMLProviderConfigInvalidID(t *testing.T) { cases := []string{"", "oidc.config"} client := &baseClient{} options := (&SAMLProviderConfigToUpdate{}). DisplayName(""). Enabled(false). 
RequestSigningEnabled(false) want := "invalid SAML provider id: " for _, tc := range cases { _, err := client.UpdateSAMLProviderConfig(context.Background(), tc, options) if err == nil || !strings.HasPrefix(err.Error(), want) { t.Errorf("UpdateSAMLProviderConfig(%q) = %v; want = %q", tc, err, want) } } } func TestUpdateSAMLProviderConfigInvalidInput(t *testing.T) { cases := []struct { name string want string conf *SAMLProviderConfigToUpdate }{ { name: "NilConfig", want: "config must not be nil", conf: nil, }, { name: "Empty", want: "no parameters specified in the update request", conf: &SAMLProviderConfigToUpdate{}, }, { name: "EmptyIDPEntityID", want: "IDPEntityID must not be empty", conf: (&SAMLProviderConfigToUpdate{}). IDPEntityID(""), }, { name: "EmptySSOURL", want: "SSOURL must not be empty", conf: (&SAMLProviderConfigToUpdate{}). SSOURL(""), }, { name: "InvalidSSOURL", want: "failed to parse SSOURL: ", conf: (&SAMLProviderConfigToUpdate{}). SSOURL("not a url"), }, { name: "EmptyX509Certs", want: "X509Certificates must not be empty", conf: (&SAMLProviderConfigToUpdate{}). X509Certificates(nil), }, { name: "EmptyStringInX509Certs", want: "X509Certificates must not contain empty strings", conf: (&SAMLProviderConfigToUpdate{}). X509Certificates([]string{""}), }, { name: "EmptyRPEntityID", want: "RPEntityID must not be empty", conf: (&SAMLProviderConfigToUpdate{}). RPEntityID(""), }, { name: "EmptyCallbackURL", want: "CallbackURL must not be empty", conf: (&SAMLProviderConfigToUpdate{}). CallbackURL(""), }, { name: "InvalidCallbackURL", want: "failed to parse CallbackURL: ", conf: (&SAMLProviderConfigToUpdate{}). 
CallbackURL("not a url"), }, } client := &baseClient{} for _, tc := range cases { _, err := client.UpdateSAMLProviderConfig(context.Background(), "saml.provider", tc.conf) if err == nil || !strings.HasPrefix(err.Error(), tc.want) { t.Errorf("UpdateSAMLProviderConfig(%q) = %v; want = %q", tc.name, err, tc.want) } } } func TestDeleteSAMLProviderConfig(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client := s.Client if err := client.DeleteSAMLProviderConfig(context.Background(), "saml.provider"); err != nil { t.Fatal(err) } req := s.Req[0] if req.Method != http.MethodDelete { t.Errorf("DeleteSAMLProviderConfig() Method = %q; want = %q", req.Method, http.MethodDelete) } wantURL := "/projects/mock-project-id/inboundSamlConfigs/saml.provider" if req.URL.Path != wantURL { t.Errorf("DeleteSAMLProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestDeleteSAMLProviderConfigInvalidID(t *testing.T) { client := &baseClient{} wantErr := "invalid SAML provider id: " for _, id := range invalidSAMLConfigIDs { err := client.DeleteSAMLProviderConfig(context.Background(), id) if err == nil || !strings.HasPrefix(err.Error(), wantErr) { t.Errorf("DeleteSAMLProviderConfig(%q) = %v; want = %q", id, err, wantErr) } } } func TestDeleteSAMLProviderConfigError(t *testing.T) { s := echoServer([]byte(notFoundResponse), t) defer s.Close() s.Status = http.StatusNotFound client := s.Client err := client.DeleteSAMLProviderConfig(context.Background(), "saml.provider") if err == nil || !IsConfigurationNotFound(err) { t.Errorf("DeleteSAMLProviderConfig() = %v; want = ConfigurationNotFound", err) } } func TestSAMLProviderConfigs(t *testing.T) { template := `{ "inboundSamlConfigs": [ %s, %s, %s ], "nextPageToken": "" }` response := fmt.Sprintf(template, samlConfigResponse, samlConfigResponse, samlConfigResponse) s := echoServer([]byte(response), t) defer s.Close() want := []*SAMLProviderConfig{ samlProviderConfig, samlProviderConfig, samlProviderConfig, } wantPath 
:= "/projects/mock-project-id/inboundSamlConfigs" testIterator := func(iter *SAMLProviderConfigIterator, token string, req string) { count := 0 for i := 0; i < len(want); i++ { config, err := iter.Next() if err == iterator.Done { break } if err != nil { t.Fatal(err) } if !reflect.DeepEqual(config, want[i]) { t.Errorf("SAMLProviderConfigs(%q) = %#v; want = %#v", token, config, want[i]) } count++ } if count != len(want) { t.Errorf("SAMLProviderConfigs(%q) = %d; want = %d", token, count, len(want)) } if _, err := iter.Next(); err != iterator.Done { t.Errorf("SAMLProviderConfigs(%q) = %v; want = %v", token, err, iterator.Done) } url := s.Req[len(s.Req)-1].URL if url.Path != wantPath { t.Errorf("SAMLProviderConfigs(%q) = %q; want = %q", token, url.Path, wantPath) } // Check the query string of the last HTTP request made. gotReq := url.Query().Encode() if gotReq != req { t.Errorf("SAMLProviderConfigs(%q) = %q; want = %v", token, gotReq, req) } } client := s.Client testIterator( client.SAMLProviderConfigs(context.Background(), ""), "", "pageSize=100") testIterator( client.SAMLProviderConfigs(context.Background(), "pageToken"), "pageToken", "pageSize=100&pageToken=pageToken") } func TestSAMLProviderConfigsError(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() s.Status = http.StatusInternalServerError client := s.Client client.baseClient.httpClient.RetryConfig = nil it := client.SAMLProviderConfigs(context.Background(), "") config, err := it.Next() if config != nil || err == nil || !errorutils.IsInternal(err) { t.Errorf("SAMLProviderConfigs() = (%v, %v); want = (nil, %q)", config, err, "internal-error") } } func TestSAMLProviderConfigNoProjectID(t *testing.T) { client := &baseClient{} want := "project id not available" if _, err := client.SAMLProviderConfig(context.Background(), "saml.provider"); err == nil || err.Error() != want { t.Errorf("SAMLProviderConfig() = %v; want = %q", err, want) } } func checkCreateOIDCConfigRequest(s *mockAuthServer, wantBody 
interface{}) error { wantURL := "/projects/mock-project-id/oauthIdpConfigs" return checkCreateOIDCConfigRequestWithURL(s, wantBody, wantURL) } func checkCreateOIDCConfigRequestWithURL(s *mockAuthServer, wantBody interface{}, wantURL string) error { req := s.Req[0] if req.Method != http.MethodPost { return fmt.Errorf("CreateOIDCProviderConfig() Method = %q; want = %q", req.Method, http.MethodPost) } if req.URL.Path != wantURL { return fmt.Errorf("CreateOIDCProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } wantQuery := "oauthIdpConfigId=oidc.provider" if req.URL.RawQuery != wantQuery { return fmt.Errorf("CreateOIDCProviderConfig() Query = %q; want = %q", req.URL.RawQuery, wantQuery) } var body map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("CreateOIDCProviderConfig() Body = %#v; want = %#v", body, wantBody) } return nil } func checkCreateSAMLConfigRequest(s *mockAuthServer, wantBody interface{}) error { wantURL := "/projects/mock-project-id/inboundSamlConfigs" return checkCreateSAMLConfigRequestWithURL(s, wantBody, wantURL) } func checkCreateSAMLConfigRequestWithURL(s *mockAuthServer, wantBody interface{}, wantURL string) error { req := s.Req[0] if req.Method != http.MethodPost { return fmt.Errorf("CreateSAMLProviderConfig() Method = %q; want = %q", req.Method, http.MethodPost) } if req.URL.Path != wantURL { return fmt.Errorf("CreateSAMLProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } wantQuery := "inboundSamlConfigId=saml.provider" if req.URL.RawQuery != wantQuery { return fmt.Errorf("CreateSAMLProviderConfig() Query = %q; want = %q", req.URL.RawQuery, wantQuery) } var body map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("CreateSAMLProviderConfig() Body = %#v; want = %#v", body, wantBody) } return nil } func checkUpdateOIDCConfigRequest(s 
*mockAuthServer, wantBody interface{}, wantMask []string) error { wantURL := "/projects/mock-project-id/oauthIdpConfigs/oidc.provider" return checkUpdateOIDCConfigRequestWithURL(s, wantBody, wantMask, wantURL) } func checkUpdateOIDCConfigRequestWithURL(s *mockAuthServer, wantBody interface{}, wantMask []string, wantURL string) error { req := s.Req[0] if req.Method != http.MethodPatch { return fmt.Errorf("UpdateOIDCProviderConfig() Method = %q; want = %q", req.Method, http.MethodPatch) } if req.URL.Path != wantURL { return fmt.Errorf("UpdateOIDCProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } queryParam := req.URL.Query().Get("updateMask") mask := strings.Split(queryParam, ",") sort.Strings(mask) if !reflect.DeepEqual(mask, wantMask) { return fmt.Errorf("UpdateOIDCProviderConfig() Query = %#v; want = %#v", mask, wantMask) } var body map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("UpdateOIDCProviderConfig() Body = %#v; want = %#v", body, wantBody) } return nil } func checkUpdateSAMLConfigRequest(s *mockAuthServer, wantBody interface{}, wantMask []string) error { wantURL := "/projects/mock-project-id/inboundSamlConfigs/saml.provider" return checkUpdateSAMLConfigRequestWithURL(s, wantBody, wantMask, wantURL) } func checkUpdateSAMLConfigRequestWithURL(s *mockAuthServer, wantBody interface{}, wantMask []string, wantURL string) error { req := s.Req[0] if req.Method != http.MethodPatch { return fmt.Errorf("UpdateSAMLProviderConfig() Method = %q; want = %q", req.Method, http.MethodPatch) } if req.URL.Path != wantURL { return fmt.Errorf("UpdateSAMLProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } queryParam := req.URL.Query().Get("updateMask") mask := strings.Split(queryParam, ",") sort.Strings(mask) if !reflect.DeepEqual(mask, wantMask) { return fmt.Errorf("UpdateSAMLProviderConfig() Query = %#v; want = %#v", mask, wantMask) } var body 
map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("UpdateSAMLProviderConfig() Body = %#v; want = %#v", body, wantBody) } return nil } golang-google-firebase-go-4.18.0/auth/tenant_mgt.go000066400000000000000000000306731505612111400221430ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "errors" "fmt" "net/http" "strconv" "strings" "firebase.google.com/go/v4/internal" "google.golang.org/api/iterator" ) // Tenant represents a tenant in a multi-tenant application. // // Multi-tenancy support requires Google Cloud's Identity Platform (GCIP). To learn more about GCIP, // including pricing and features, see https://cloud.google.com/identity-platform. // // Before multi-tenancy can be used in a Google Cloud Identity Platform project, tenants must be // enabled in that project via the Cloud Console UI. // // A tenant configuration provides information such as the display name, tenant identifier and email // authentication configuration. For OIDC/SAML provider configuration management, TenantClient // instances should be used instead of a Tenant to retrieve the list of configured IdPs on a tenant. // When configuring these providers, note that tenants will inherit whitelisted domains and // authenticated redirect URIs of their parent project. 
// // All other settings of a tenant will also be inherited. These will need to be managed from the // Cloud Console UI. type Tenant struct { ID string `json:"name"` DisplayName string `json:"displayName"` AllowPasswordSignUp bool `json:"allowPasswordSignup"` EnableEmailLinkSignIn bool `json:"enableEmailLinkSignin"` EnableAnonymousUsers bool `json:"enableAnonymousUser"` MultiFactorConfig *MultiFactorConfig `json:"mfaConfig"` } // TenantClient is used for managing users, configuring SAML/OIDC providers, and generating email // links for specific tenants. // // Before multi-tenancy can be used in a Google Cloud Identity Platform project, tenants must be // enabled in that project via the Cloud Console UI. // // Each tenant contains its own identity providers, settings and users. TenantClient enables // managing users and SAML/OIDC configurations of specific tenants. It also supports verifying ID // tokens issued to users who are signed into specific tenants. // // TenantClient instances for a specific tenantID can be instantiated by calling // [TenantManager.AuthForTenant(tenantID)]. type TenantClient struct { *baseClient } // TenantID returns the ID of the tenant to which this TenantClient instance belongs. func (tc *TenantClient) TenantID() string { return tc.tenantID } // TenantManager is the interface used to manage tenants in a multi-tenant application. // // This supports creating, updating, listing, deleting the tenants of a Firebase project. It also // supports creating new TenantClient instances scoped to specific tenant IDs. type TenantManager struct { base *baseClient endpoint string projectID string httpClient *internal.HTTPClient } func newTenantManager(client *internal.HTTPClient, conf *internal.AuthConfig, base *baseClient) *TenantManager { return &TenantManager{ base: base, endpoint: base.tenantMgtEndpoint, projectID: conf.ProjectID, httpClient: client, } } // AuthForTenant creates a new TenantClient scoped to a given tenantID. 
func (tm *TenantManager) AuthForTenant(tenantID string) (*TenantClient, error) { if tenantID == "" { return nil, errors.New("tenantID must not be empty") } return &TenantClient{ baseClient: tm.base.withTenantID(tenantID), }, nil } // Tenant returns the tenant with the given ID. func (tm *TenantManager) Tenant(ctx context.Context, tenantID string) (*Tenant, error) { if tenantID == "" { return nil, errors.New("tenantID must not be empty") } req := &internal.Request{ Method: http.MethodGet, URL: fmt.Sprintf("/tenants/%s", tenantID), } var tenant Tenant if _, err := tm.makeRequest(ctx, req, &tenant); err != nil { return nil, err } tenant.ID = extractResourceID(tenant.ID) return &tenant, nil } // CreateTenant creates a new tenant with the given options. func (tm *TenantManager) CreateTenant(ctx context.Context, tenant *TenantToCreate) (*Tenant, error) { if tenant == nil { return nil, errors.New("tenant must not be nil") } if err := tenant.validate(); err != nil { return nil, err } req := &internal.Request{ Method: http.MethodPost, URL: "/tenants", Body: internal.NewJSONEntity(tenant.ensureParams()), } var result Tenant if _, err := tm.makeRequest(ctx, req, &result); err != nil { return nil, err } result.ID = extractResourceID(result.ID) return &result, nil } // UpdateTenant updates an existing tenant with the given options. 
func (tm *TenantManager) UpdateTenant(ctx context.Context, tenantID string, tenant *TenantToUpdate) (*Tenant, error) { if tenantID == "" { return nil, errors.New("tenantID must not be empty") } if tenant == nil { return nil, errors.New("tenant must not be nil") } if err := tenant.validate(); err != nil { return nil, err } mask := tenant.params.UpdateMask() if len(mask) == 0 { return nil, errors.New("no parameters specified in the update request") } req := &internal.Request{ Method: http.MethodPatch, URL: fmt.Sprintf("/tenants/%s", tenantID), Body: internal.NewJSONEntity(tenant.params), Opts: []internal.HTTPOption{ internal.WithQueryParam("updateMask", strings.Join(mask, ",")), }, } var result Tenant if _, err := tm.makeRequest(ctx, req, &result); err != nil { return nil, err } result.ID = extractResourceID(result.ID) return &result, nil } // DeleteTenant deletes the tenant with the given ID. func (tm *TenantManager) DeleteTenant(ctx context.Context, tenantID string) error { if tenantID == "" { return errors.New("tenantID must not be empty") } req := &internal.Request{ Method: http.MethodDelete, URL: fmt.Sprintf("/tenants/%s", tenantID), } _, err := tm.makeRequest(ctx, req, nil) return err } // Tenants returns an iterator over tenants in the project. // // If nextPageToken is empty, the iterator will start at the beginning. Otherwise, // iterator starts after the token. 
func (tm *TenantManager) Tenants(ctx context.Context, nextPageToken string) *TenantIterator { it := &TenantIterator{ ctx: ctx, tm: tm, } it.pageInfo, it.nextFunc = iterator.NewPageInfo( it.fetch, func() int { return len(it.tenants) }, func() interface{} { b := it.tenants; it.tenants = nil; return b }) it.pageInfo.MaxSize = maxConfigs it.pageInfo.Token = nextPageToken return it } func (tm *TenantManager) makeRequest(ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) { if tm.projectID == "" { return nil, errors.New("project id not available") } req.URL = fmt.Sprintf("%s/projects/%s%s", tm.endpoint, tm.projectID, req.URL) return tm.httpClient.DoAndUnmarshal(ctx, req, v) } const ( tenantDisplayNameKey = "displayName" allowPasswordSignUpKey = "allowPasswordSignup" enableEmailLinkSignInKey = "enableEmailLinkSignin" enableAnonymousUser = "enableAnonymousUser" multiFactorConfigTenantKey = "mfaConfig" ) // TenantToCreate represents the options used to create a new tenant. type TenantToCreate struct { params nestedMap } // DisplayName sets the display name of the new tenant. func (t *TenantToCreate) DisplayName(name string) *TenantToCreate { return t.set(tenantDisplayNameKey, name) } // AllowPasswordSignUp enables or disables email sign-in provider. func (t *TenantToCreate) AllowPasswordSignUp(allow bool) *TenantToCreate { return t.set(allowPasswordSignUpKey, allow) } // EnableEmailLinkSignIn enables or disables email link sign-in. // // Disabling this makes the password required for email sign-in. func (t *TenantToCreate) EnableEmailLinkSignIn(enable bool) *TenantToCreate { return t.set(enableEmailLinkSignInKey, enable) } // EnableAnonymousUsers enables or disables anonymous authentication. 
func (t *TenantToCreate) EnableAnonymousUsers(enable bool) *TenantToCreate { return t.set(enableAnonymousUser, enable) } // MultiFactorConfig configures the tenant's multi-factor settings func (t *TenantToCreate) MultiFactorConfig(multiFactorConfig MultiFactorConfig) *TenantToCreate { return t.set(multiFactorConfigTenantKey, multiFactorConfig) } func (t *TenantToCreate) set(key string, value interface{}) *TenantToCreate { t.ensureParams().Set(key, value) return t } func (t *TenantToCreate) ensureParams() nestedMap { if t.params == nil { t.params = make(nestedMap) } return t.params } func (t *TenantToCreate) validate() error { req := make(map[string]interface{}) for k, v := range t.params { req[k] = v } val, ok := req[multiFactorConfigTenantKey] if ok { multiFactorConfig, ok := val.(MultiFactorConfig) if !ok { return fmt.Errorf("invalid type for MultiFactorConfig: %s", req[multiFactorConfigProjectKey]) } if err := multiFactorConfig.validate(); err != nil { return err } } return nil } // TenantToUpdate represents the options used to update an existing tenant. type TenantToUpdate struct { params nestedMap } // DisplayName sets the display name of the new tenant. func (t *TenantToUpdate) DisplayName(name string) *TenantToUpdate { return t.set(tenantDisplayNameKey, name) } // AllowPasswordSignUp enables or disables email sign-in provider. func (t *TenantToUpdate) AllowPasswordSignUp(allow bool) *TenantToUpdate { return t.set(allowPasswordSignUpKey, allow) } // EnableEmailLinkSignIn enables or disables email link sign-in. // // Disabling this makes the password required for email sign-in. func (t *TenantToUpdate) EnableEmailLinkSignIn(enable bool) *TenantToUpdate { return t.set(enableEmailLinkSignInKey, enable) } // EnableAnonymousUsers enables or disables anonymous authentication. 
func (t *TenantToUpdate) EnableAnonymousUsers(enable bool) *TenantToUpdate { return t.set(enableAnonymousUser, enable) } // MultiFactorConfig configures the tenant's multi-factor settings func (t *TenantToUpdate) MultiFactorConfig(multiFactorConfig MultiFactorConfig) *TenantToUpdate { return t.set(multiFactorConfigTenantKey, multiFactorConfig) } func (t *TenantToUpdate) set(key string, value interface{}) *TenantToUpdate { if t.params == nil { t.params = make(nestedMap) } t.params.Set(key, value) return t } func (t *TenantToUpdate) validate() error { req := make(map[string]interface{}) for k, v := range t.params { req[k] = v } val, ok := req[multiFactorConfigTenantKey] if ok { multiFactorConfig, ok := val.(MultiFactorConfig) if !ok { return fmt.Errorf("invalid type for MultiFactorConfig: %s", req[multiFactorConfigProjectKey]) } if err := multiFactorConfig.validate(); err != nil { return err } } return nil } // TenantIterator is an iterator over tenants. type TenantIterator struct { tm *TenantManager ctx context.Context nextFunc func() error pageInfo *iterator.PageInfo tenants []*Tenant } // PageInfo supports pagination. func (it *TenantIterator) PageInfo() *iterator.PageInfo { return it.pageInfo } // Next returns the next Tenant. The error value of [iterator.Done] is // returned if there are no more results. Once Next returns [iterator.Done], all // subsequent calls will return [iterator.Done]. 
func (it *TenantIterator) Next() (*Tenant, error) { if err := it.nextFunc(); err != nil { return nil, err } tenant := it.tenants[0] it.tenants = it.tenants[1:] return tenant, nil } func (it *TenantIterator) fetch(pageSize int, pageToken string) (string, error) { params := map[string]string{ "pageSize": strconv.Itoa(pageSize), } if pageToken != "" { params["pageToken"] = pageToken } req := &internal.Request{ Method: http.MethodGet, URL: "/tenants", Opts: []internal.HTTPOption{ internal.WithQueryParams(params), }, } var result struct { Tenants []Tenant `json:"tenants"` NextPageToken string `json:"nextPageToken"` } if _, err := it.tm.makeRequest(it.ctx, req, &result); err != nil { return "", err } for i := range result.Tenants { result.Tenants[i].ID = extractResourceID(result.Tenants[i].ID) it.tenants = append(it.tenants, &result.Tenants[i]) } it.pageInfo.Token = result.NextPageToken return result.NextPageToken, nil } golang-google-firebase-go-4.18.0/auth/tenant_mgt_test.go000066400000000000000000001323461505612111400232020ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "encoding/json" "fmt" "io/ioutil" "net/http" "reflect" "sort" "strconv" "strings" "testing" "time" "firebase.google.com/go/v4/errorutils" "google.golang.org/api/iterator" ) func TestAuthForTenantEmptyTenantID(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("") if client != nil || err == nil { t.Errorf("AuthForTenant() = (%v, %v); want = (nil, error)", client, err) } } func TestTenantID(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } const want = "tenantID" tenantID := client.TenantID() if tenantID != want { t.Errorf("TenantID() = %q; want = %q", tenantID, want) } if client.baseClient.tenantID != want { t.Errorf("baseClient.tenantID = %q; want = %q", client.baseClient.tenantID, want) } } func TestTenantGetUser(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } user, err := client.GetUser(context.Background(), "ignored_id") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(user, testUser) { t.Errorf("GetUser() = %#v; want = %#v", user, testUser) } want := `{"localId":["ignored_id"]}` got := string(s.Rbody) if got != want { t.Errorf("GetUser() Req = %v; want = %v", got, want) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:lookup" if s.Req[0].RequestURI != wantPath { t.Errorf("GetUser() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestTenantGetUserByEmail(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } user, err := client.GetUserByEmail(context.Background(), "test@email.com") if err != nil { 
t.Fatal(err) } if !reflect.DeepEqual(user, testUser) { t.Errorf("GetUserByEmail() = %#v; want = %#v", user, testUser) } want := `{"email":["test@email.com"]}` got := string(s.Rbody) if got != want { t.Errorf("GetUserByEmail() Req = %v; want = %v", got, want) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:lookup" if s.Req[0].RequestURI != wantPath { t.Errorf("GetUserByEmail() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestTenantGetUserByPhoneNumber(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } user, err := client.GetUserByPhoneNumber(context.Background(), "+1234567890") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(user, testUser) { t.Errorf("GetUserByPhoneNumber() = %#v; want = %#v", user, testUser) } want := `{"phoneNumber":["+1234567890"]}` got := string(s.Rbody) if got != want { t.Errorf("GetUserByPhoneNumber() Req = %v; want = %v", got, want) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:lookup" if s.Req[0].RequestURI != wantPath { t.Errorf("GetUserByPhoneNumber() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestTenantListUsers(t *testing.T) { testListUsersResponse, err := ioutil.ReadFile("../testdata/list_users.json") if err != nil { t.Fatal(err) } s := echoServer(testListUsersResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } want := []*ExportedUserRecord{ {UserRecord: testUser, PasswordHash: "passwordhash1", PasswordSalt: "salt1"}, {UserRecord: testUser, PasswordHash: "passwordhash2", PasswordSalt: "salt2"}, {UserRecord: testUserWithoutMFA, PasswordHash: "passwordhash3", PasswordSalt: "salt3"}, } testIterator := func(iter *UserIterator, token string, req string) { count := 0 for i := 0; i < len(want); i++ { user, err := iter.Next() if 
err == iterator.Done { break } if err != nil { t.Fatal(err) } if !reflect.DeepEqual(user.UserRecord, want[i].UserRecord) { t.Errorf("Users(%q) = %#v; want = %#v", token, user, want[i]) } if user.PasswordHash != want[i].PasswordHash { t.Errorf("Users(%q) PasswordHash = %q; want = %q", token, user.PasswordHash, want[i].PasswordHash) } if user.PasswordSalt != want[i].PasswordSalt { t.Errorf("Users(%q) PasswordSalt = %q; want = %q", token, user.PasswordSalt, want[i].PasswordSalt) } count++ } if count != len(want) { t.Errorf("Users(%q) = %d; want = %d", token, count, len(want)) } if _, err := iter.Next(); err != iterator.Done { t.Errorf("Users(%q) = %v, want = %v", token, err, iterator.Done) } hr := s.Req[len(s.Req)-1] // Check the query string of the last HTTP request made. gotReq := hr.URL.Query().Encode() if gotReq != req { t.Errorf("Users(%q) = %q, want = %v", token, gotReq, req) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:batchGet" if hr.URL.Path != wantPath { t.Errorf("Users(%q) URL = %q; want = %q", token, hr.URL.Path, wantPath) } } testIterator( client.Users(context.Background(), ""), "", "maxResults=1000") testIterator( client.Users(context.Background(), "pageToken"), "pageToken", "maxResults=1000&nextPageToken=pageToken") } func TestTenantCreateUser(t *testing.T) { resp := `{ "kind": "identitytoolkit#SignupNewUserResponse", "localId": "expectedUserID" }` s := echoServer([]byte(resp), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts" for _, tc := range createUserCases { uid, err := client.createUser(context.Background(), tc.params) if uid != "expectedUserID" || err != nil { t.Errorf("createUser(%#v) = (%q, %v); want = (%q, nil)", tc.params, uid, err, "expectedUserID") } want, err := json.Marshal(tc.req) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(s.Rbody, want) { 
t.Errorf("createUser(%#v) request = %v; want = %v", tc.params, string(s.Rbody), string(want)) } if s.Req[0].RequestURI != wantPath { t.Errorf("createUser(%#v) URL = %q; want = %q", tc.params, s.Req[0].RequestURI, wantPath) } } } func TestTenantUpdateUser(t *testing.T) { resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "expectedUserID" }` s := echoServer([]byte(resp), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:update" for _, tc := range updateUserCases { err := client.updateUser(context.Background(), "uid", tc.params) if err != nil { t.Errorf("updateUser(%v) = %v; want = nil", tc.params, err) } tc.req["localId"] = "uid" want, err := json.Marshal(tc.req) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(s.Rbody, want) { t.Errorf("updateUser() request = %v; want = %v", string(s.Rbody), string(want)) } if s.Req[0].RequestURI != wantPath { t.Errorf("updateUser(%#v) URL = %q; want = %q", tc.params, s.Req[0].RequestURI, wantPath) } } } func TestTenantRevokeRefreshTokens(t *testing.T) { resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "expectedUserID" }` s := echoServer([]byte(resp), t) defer s.Close() tc, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } before := time.Now().Unix() if err := tc.RevokeRefreshTokens(context.Background(), "some_uid"); err != nil { t.Error(err) } after := time.Now().Unix() var req struct { ValidSince string `json:"validSince"` } if err := json.Unmarshal(s.Rbody, &req); err != nil { t.Fatal(err) } validSince, err := strconv.ParseInt(req.ValidSince, 10, 64) if err != nil { t.Fatal(err) } if validSince > after || validSince < before { t.Errorf("validSince = %d, expecting time between %d and %d", validSince, before, after) } wantPath := 
"/projects/mock-project-id/tenants/tenantID/accounts:update" if s.Req[0].RequestURI != wantPath { t.Errorf("RevokeRefreshTokens() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestTenantSetCustomUserClaims(t *testing.T) { resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "uid" }` s := echoServer([]byte(resp), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:update" for _, tc := range setCustomUserClaimsCases { err := client.SetCustomUserClaims(context.Background(), "uid", tc) if err != nil { t.Errorf("SetCustomUserClaims(%v) = %v; want nil", tc, err) } input := tc if input == nil { input = map[string]interface{}{} } b, err := json.Marshal(input) if err != nil { t.Fatal(err) } m := map[string]interface{}{ "localId": "uid", "customAttributes": string(b), } want, err := json.Marshal(m) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(s.Rbody, want) { t.Errorf("SetCustomUserClaims() = %v; want = %v", string(s.Rbody), string(want)) } hr := s.Req[len(s.Req)-1] if hr.RequestURI != wantPath { t.Errorf("RevokeRefreshTokens() URL = %q; want = %q", hr.RequestURI, wantPath) } } } func TestTenantImportUsers(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } users := []*UserToImport{ (&UserToImport{}).UID("user1"), (&UserToImport{}).UID("user2"), } result, err := client.ImportUsers(context.Background(), users) if err != nil { t.Fatal(err) } if result.SuccessCount != 2 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 2, FailureCount: 0}", result) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:batchCreate" if s.Req[0].RequestURI != wantPath { t.Errorf("ImportUsers() URL = %q; want = %q", 
s.Req[0].RequestURI, wantPath) } } func TestTenantImportUsersWithHash(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } users := []*UserToImport{ (&UserToImport{}).UID("user1").PasswordHash([]byte("password")), (&UserToImport{}).UID("user2"), } result, err := client.ImportUsers(context.Background(), users, WithHash(mockHash{ key: "key", saltSep: ",", rounds: 8, memoryCost: 14, })) if err != nil { t.Fatal(err) } if result.SuccessCount != 2 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 2, FailureCount: 0}", result) } var got map[string]interface{} if err := json.Unmarshal(s.Rbody, &got); err != nil { t.Fatal(err) } want := map[string]interface{}{ "hashAlgorithm": "MOCKHASH", "signerKey": "key", "saltSeparator": ",", "rounds": float64(8), "memoryCost": float64(14), } for k, v := range want { gv, ok := got[k] if !ok || gv != v { t.Errorf("ImportUsers() request(%q) = %v; want = %v", k, gv, v) } } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:batchCreate" if s.Req[0].RequestURI != wantPath { t.Errorf("ImportUsers() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestTenantDeleteUser(t *testing.T) { resp := `{ "kind": "identitytoolkit#SignupNewUserResponse", "email": "", "localId": "expectedUserID" }` s := echoServer([]byte(resp), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } if err := client.DeleteUser(context.Background(), "uid"); err != nil { t.Errorf("DeleteUser() = %v; want = nil", err) } wantPath := "/projects/mock-project-id/tenants/tenantID/accounts:delete" if s.Req[0].RequestURI != wantPath { t.Errorf("DeleteUser() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } const wantEmailActionURL = "/projects/mock-project-id/tenants/tenantID/accounts:sendOobCode" func 
TestTenantEmailVerificationLink(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } link, err := client.EmailVerificationLink(context.Background(), testEmail) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("EmailVerificationLink() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "VERIFY_EMAIL", "email": testEmail, "returnOobLink": true, } if err := checkActionLinkRequestWithURL(want, wantEmailActionURL, s); err != nil { t.Fatalf("EmailVerificationLink() %v", err) } } func TestTenantPasswordResetLink(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } link, err := client.PasswordResetLink(context.Background(), testEmail) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("PasswordResetLink() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "PASSWORD_RESET", "email": testEmail, "returnOobLink": true, } if err := checkActionLinkRequestWithURL(want, wantEmailActionURL, s); err != nil { t.Fatalf("PasswordResetLink() %v", err) } } func TestTenantEmailSignInLink(t *testing.T) { s := echoServer(testActionLinkResponse, t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } link, err := client.EmailSignInLink(context.Background(), testEmail, testActionCodeSettings) if err != nil { t.Fatal(err) } if link != testActionLink { t.Errorf("EmailSignInLink() = %q; want = %q", link, testActionLink) } want := map[string]interface{}{ "requestType": "EMAIL_SIGNIN", "email": testEmail, "returnOobLink": true, } for k, v := range testActionCodeSettingsMap { want[k] = v } if err := 
checkActionLinkRequestWithURL(want, wantEmailActionURL, s); err != nil { t.Fatalf("EmailSignInLink() %v", err) } } func TestTenantOIDCProviderConfig(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } oidc, err := client.OIDCProviderConfig(context.Background(), "oidc.provider") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("OIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } req := s.Req[0] if req.Method != http.MethodGet { t.Errorf("OIDCProviderConfig() Method = %q; want = %q", req.Method, http.MethodGet) } wantURL := "/projects/mock-project-id/tenants/tenantID/oauthIdpConfigs/oidc.provider" if req.URL.Path != wantURL { t.Errorf("OIDCProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestTenantCreateOIDCProviderConfig(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } options := (&OIDCProviderConfigToCreate{}). ID(oidcProviderConfig.ID). DisplayName(oidcProviderConfig.DisplayName). Enabled(oidcProviderConfig.Enabled). ClientID(oidcProviderConfig.ClientID). 
Issuer(oidcProviderConfig.Issuer) oidc, err := client.CreateOIDCProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("CreateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": oidcProviderConfig.DisplayName, "enabled": oidcProviderConfig.Enabled, "clientId": oidcProviderConfig.ClientID, "issuer": oidcProviderConfig.Issuer, } wantURL := "/projects/mock-project-id/tenants/tenantID/oauthIdpConfigs" if err := checkCreateOIDCConfigRequestWithURL(s, wantBody, wantURL); err != nil { t.Fatal(err) } } func TestTenantUpdateOIDCProviderConfig(t *testing.T) { s := echoServer([]byte(oidcConfigResponse), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } options := (&OIDCProviderConfigToUpdate{}). DisplayName(oidcProviderConfig.DisplayName). Enabled(oidcProviderConfig.Enabled). ClientID(oidcProviderConfig.ClientID). 
Issuer(oidcProviderConfig.Issuer) oidc, err := client.UpdateOIDCProviderConfig(context.Background(), "oidc.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(oidc, oidcProviderConfig) { t.Errorf("UpdateOIDCProviderConfig() = %#v; want = %#v", oidc, oidcProviderConfig) } wantBody := map[string]interface{}{ "displayName": oidcProviderConfig.DisplayName, "enabled": oidcProviderConfig.Enabled, "clientId": oidcProviderConfig.ClientID, "issuer": oidcProviderConfig.Issuer, } wantMask := []string{ "clientId", "displayName", "enabled", "issuer", } wantURL := "/projects/mock-project-id/tenants/tenantID/oauthIdpConfigs/oidc.provider" if err := checkUpdateOIDCConfigRequestWithURL(s, wantBody, wantMask, wantURL); err != nil { t.Fatal(err) } } func TestTenantDeleteOIDCProviderConfig(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } if err := client.DeleteOIDCProviderConfig(context.Background(), "oidc.provider"); err != nil { t.Fatal(err) } req := s.Req[0] if req.Method != http.MethodDelete { t.Errorf("DeleteOIDCProviderConfig() Method = %q; want = %q", req.Method, http.MethodDelete) } wantURL := "/projects/mock-project-id/tenants/tenantID/oauthIdpConfigs/oidc.provider" if req.URL.Path != wantURL { t.Errorf("DeleteOIDCProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestTenantOIDCProviderConfigs(t *testing.T) { template := `{ "oauthIdpConfigs": [ %s, %s, %s ], "nextPageToken": "" }` response := fmt.Sprintf(template, oidcConfigResponse, oidcConfigResponse, oidcConfigResponse) s := echoServer([]byte(response), t) defer s.Close() want := []*OIDCProviderConfig{ oidcProviderConfig, oidcProviderConfig, oidcProviderConfig, } wantPath := "/projects/mock-project-id/tenants/tenantID/oauthIdpConfigs" testIterator := func(iter *OIDCProviderConfigIterator, token string, req string) { count := 0 for i := 0; i < 
len(want); i++ { config, err := iter.Next() if err == iterator.Done { break } if err != nil { t.Fatal(err) } if !reflect.DeepEqual(config, want[i]) { t.Errorf("OIDCProviderConfigs(%q) = %#v; want = %#v", token, config, want[i]) } count++ } if count != len(want) { t.Errorf("OIDCProviderConfigs(%q) = %d; want = %d", token, count, len(want)) } if _, err := iter.Next(); err != iterator.Done { t.Errorf("OIDCProviderConfigs(%q) = %v; want = %v", token, err, iterator.Done) } url := s.Req[len(s.Req)-1].URL if url.Path != wantPath { t.Errorf("OIDCProviderConfigs(%q) = %q; want = %q", token, url.Path, wantPath) } // Check the query string of the last HTTP request made. gotReq := url.Query().Encode() if gotReq != req { t.Errorf("OIDCProviderConfigs(%q) = %q; want = %v", token, gotReq, req) } } client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } testIterator( client.OIDCProviderConfigs(context.Background(), ""), "", "pageSize=100") testIterator( client.OIDCProviderConfigs(context.Background(), "pageToken"), "pageToken", "pageSize=100&pageToken=pageToken") } func TestTenantSAMLProviderConfig(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } saml, err := client.SAMLProviderConfig(context.Background(), "saml.provider") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("SAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } req := s.Req[0] if req.Method != http.MethodGet { t.Errorf("SAMLProviderConfig() Method = %q; want = %q", req.Method, http.MethodGet) } wantURL := "/projects/mock-project-id/tenants/tenantID/inboundSamlConfigs/saml.provider" if req.URL.Path != wantURL { t.Errorf("SAMLProviderConfig() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestTenantCreateSAMLProviderConfig(t *testing.T) { s := 
echoServer([]byte(samlConfigResponse), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } options := (&SAMLProviderConfigToCreate{}). ID(samlProviderConfig.ID). DisplayName(samlProviderConfig.DisplayName). Enabled(samlProviderConfig.Enabled). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). RequestSigningEnabled(samlProviderConfig.RequestSigningEnabled). X509Certificates(samlProviderConfig.X509Certificates). RPEntityID(samlProviderConfig.RPEntityID). CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.CreateSAMLProviderConfig(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("CreateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": samlProviderConfig.DisplayName, "enabled": samlProviderConfig.Enabled, "idpConfig": map[string]interface{}{ "idpEntityId": samlProviderConfig.IDPEntityID, "ssoUrl": samlProviderConfig.SSOURL, "signRequest": samlProviderConfig.RequestSigningEnabled, "idpCertificates": idpCertsMap, }, "spConfig": map[string]interface{}{ "spEntityId": samlProviderConfig.RPEntityID, "callbackUri": samlProviderConfig.CallbackURL, }, } wantURL := "/projects/mock-project-id/tenants/tenantID/inboundSamlConfigs" if err := checkCreateSAMLConfigRequestWithURL(s, wantBody, wantURL); err != nil { t.Fatal(err) } } func TestTenantUpdateSAMLProviderConfig(t *testing.T) { s := echoServer([]byte(samlConfigResponse), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } options := (&SAMLProviderConfigToUpdate{}). DisplayName(samlProviderConfig.DisplayName). Enabled(samlProviderConfig.Enabled). IDPEntityID(samlProviderConfig.IDPEntityID). SSOURL(samlProviderConfig.SSOURL). 
RequestSigningEnabled(samlProviderConfig.RequestSigningEnabled). X509Certificates(samlProviderConfig.X509Certificates). RPEntityID(samlProviderConfig.RPEntityID). CallbackURL(samlProviderConfig.CallbackURL) saml, err := client.UpdateSAMLProviderConfig(context.Background(), "saml.provider", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(saml, samlProviderConfig) { t.Errorf("UpdateSAMLProviderConfig() = %#v; want = %#v", saml, samlProviderConfig) } wantBody := map[string]interface{}{ "displayName": samlProviderConfig.DisplayName, "enabled": samlProviderConfig.Enabled, "idpConfig": map[string]interface{}{ "idpEntityId": samlProviderConfig.IDPEntityID, "ssoUrl": samlProviderConfig.SSOURL, "signRequest": samlProviderConfig.RequestSigningEnabled, "idpCertificates": idpCertsMap, }, "spConfig": map[string]interface{}{ "spEntityId": samlProviderConfig.RPEntityID, "callbackUri": samlProviderConfig.CallbackURL, }, } wantMask := []string{ "displayName", "enabled", "idpConfig.idpCertificates", "idpConfig.idpEntityId", "idpConfig.signRequest", "idpConfig.ssoUrl", "spConfig.callbackUri", "spConfig.spEntityId", } wantURL := "/projects/mock-project-id/tenants/tenantID/inboundSamlConfigs/saml.provider" if err := checkUpdateSAMLConfigRequestWithURL(s, wantBody, wantMask, wantURL); err != nil { t.Fatal(err) } } func TestTenantDeleteSAMLProviderConfig(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } if err := client.DeleteSAMLProviderConfig(context.Background(), "saml.provider"); err != nil { t.Fatal(err) } req := s.Req[0] if req.Method != http.MethodDelete { t.Errorf("DeleteSAMLProviderConfig() Method = %q; want = %q", req.Method, http.MethodDelete) } wantURL := "/projects/mock-project-id/tenants/tenantID/inboundSamlConfigs/saml.provider" if req.URL.Path != wantURL { t.Errorf("DeleteSAMLProviderConfig() URL = %q; want = %q", 
req.URL.Path, wantURL) } } func TestTenantSAMLProviderConfigs(t *testing.T) { template := `{ "inboundSamlConfigs": [ %s, %s, %s ], "nextPageToken": "" }` response := fmt.Sprintf(template, samlConfigResponse, samlConfigResponse, samlConfigResponse) s := echoServer([]byte(response), t) defer s.Close() want := []*SAMLProviderConfig{ samlProviderConfig, samlProviderConfig, samlProviderConfig, } wantPath := "/projects/mock-project-id/tenants/tenantID/inboundSamlConfigs" testIterator := func(iter *SAMLProviderConfigIterator, token string, req string) { count := 0 for i := 0; i < len(want); i++ { config, err := iter.Next() if err == iterator.Done { break } if err != nil { t.Fatal(err) } if !reflect.DeepEqual(config, want[i]) { t.Errorf("SAMLProviderConfigs(%q) = %#v; want = %#v", token, config, want[i]) } count++ } if count != len(want) { t.Errorf("SAMLProviderConfigs(%q) = %d; want = %d", token, count, len(want)) } if _, err := iter.Next(); err != iterator.Done { t.Errorf("SAMLProviderConfigs(%q) = %v; want = %v", token, err, iterator.Done) } url := s.Req[len(s.Req)-1].URL if url.Path != wantPath { t.Errorf("SAMLProviderConfigs(%q) = %q; want = %q", token, url.Path, wantPath) } // Check the query string of the last HTTP request made. 
gotReq := url.Query().Encode() if gotReq != req { t.Errorf("SAMLProviderConfigs(%q) = %q; want = %v", token, gotReq, req) } } client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } testIterator( client.SAMLProviderConfigs(context.Background(), ""), "", "pageSize=100") testIterator( client.SAMLProviderConfigs(context.Background(), "pageToken"), "pageToken", "pageSize=100&pageToken=pageToken") } func TestTenantVerifyIDToken(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.TenantManager.base.idTokenVerifier = testIDTokenVerifier client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } idToken := getIDToken(mockIDTokenPayload{ "firebase": map[string]interface{}{ "tenant": "tenantID", "sign_in_provider": "custom", }, }) ft, err := client.VerifyIDToken(context.Background(), idToken) if err != nil { t.Fatal(err) } if ft.Firebase.SignInProvider != "custom" { t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom") } if ft.Firebase.Tenant != "tenantID" { t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "tenantID") } } func TestTenantVerifyIDTokenAndCheckRevoked(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.TenantManager.base.idTokenVerifier = testIDTokenVerifier client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } idToken := getIDToken(mockIDTokenPayload{ "firebase": map[string]interface{}{ "tenant": "tenantID", "sign_in_provider": "custom", }, }) ft, err := client.VerifyIDTokenAndCheckRevoked(context.Background(), idToken) if err != nil { t.Fatal(err) } if ft.Firebase.SignInProvider != "custom" { t.Errorf("SignInProvider = %q; want = %q", ft.Firebase.SignInProvider, "custom") } if ft.Firebase.Tenant != "tenantID" { t.Errorf("Tenant = %q; want = %q", ft.Firebase.Tenant, "tenantID") } 
wantURI := "/projects/mock-project-id/tenants/tenantID/accounts:lookup" if s.Req[0].RequestURI != wantURI { t.Errorf("VerifySessionCookieAndCheckRevoked() URL = %q; want = %q", s.Req[0].RequestURI, wantURI) } } func TestInvalidTenantVerifyIDToken(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() s.Client.TenantManager.base.idTokenVerifier = testIDTokenVerifier client, err := s.Client.TenantManager.AuthForTenant("tenantID") if err != nil { t.Fatalf("AuthForTenant() = %v", err) } idToken := getIDToken(mockIDTokenPayload{ "firebase": map[string]interface{}{ "tenant": "invalidTenantID", "sign_in_provider": "custom", }, }) ft, err := client.VerifyIDToken(context.Background(), idToken) if ft != nil || err == nil || !IsTenantIDMismatch(err) { t.Errorf("VerifyIDToken() = (%v, %v); want = (nil, %q)", ft, err, tenantIDMismatch) } } const tenantResponse = `{ "name":"projects/mock-project-id/tenants/tenantID", "displayName": "Test Tenant", "allowPasswordSignup": true, "enableEmailLinkSignin": true, "enableAnonymousUser": true, "mfaConfig": { "providerConfigs": [ { "state":"ENABLED", "totpProviderConfig":{ "adjacentIntervals":5 } } ] } }` const tenantResponse2 = `{ "name":"projects/mock-project-id/tenants/tenantID2", "displayName": "Test Tenant 2", "allowPasswordSignup": true, "enableEmailLinkSignin": true, "enableAnonymousUser": true, "mfaConfig": { "providerConfigs": [ { "state":"ENABLED", "totpProviderConfig":{ "adjacentIntervals":5 } } ] } }` const tenantNotFoundResponse = `{ "error": { "message": "TENANT_NOT_FOUND" } }` var testTenant = &Tenant{ ID: "tenantID", DisplayName: "Test Tenant", AllowPasswordSignUp: true, EnableEmailLinkSignIn: true, EnableAnonymousUsers: true, MultiFactorConfig: &MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{ { State: Enabled, TOTPProviderConfig: &TOTPProviderConfig{ AdjacentIntervals: 5, }, }, }, }, } var testTenant2 = &Tenant{ ID: "tenantID2", DisplayName: "Test Tenant 2", AllowPasswordSignUp: true, 
EnableEmailLinkSignIn: true, EnableAnonymousUsers: true, MultiFactorConfig: &MultiFactorConfig{ ProviderConfigs: []*ProviderConfig{ { State: Enabled, TOTPProviderConfig: &TOTPProviderConfig{ AdjacentIntervals: 5, }, }, }, }, } func TestTenant(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client tenant, err := client.TenantManager.Tenant(context.Background(), "tenantID") if err != nil { t.Fatalf("Tenant() = %v", err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("Tenant() = %#v; want = %#v", tenant, testTenant) } req := s.Req[0] if req.Method != http.MethodGet { t.Errorf("Tenant() Method = %q; want = %q", req.Method, http.MethodGet) } wantURL := "/projects/mock-project-id/tenants/tenantID" if req.URL.Path != wantURL { t.Errorf("Tenant() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestTenantEmptyID(t *testing.T) { tm := &TenantManager{} wantErr := "tenantID must not be empty" tenant, err := tm.Tenant(context.Background(), "") if tenant != nil || err == nil || err.Error() != wantErr { t.Errorf("Tenant('') = (%v, %v); want = (nil, %q)", tenant, err, wantErr) } } func TestTenantError(t *testing.T) { s := echoServer([]byte(tenantNotFoundResponse), t) defer s.Close() s.Status = http.StatusNotFound client := s.Client tenant, err := client.TenantManager.Tenant(context.Background(), "tenantID") if tenant != nil || err == nil || !IsTenantNotFound(err) { t.Errorf("Tenant() = (%v, %v); want = (nil, TenantNotFound)", tenant, err) } } func TestTenantNoProjectID(t *testing.T) { tm := &TenantManager{} want := "project id not available" if _, err := tm.Tenant(context.Background(), "tenantID"); err == nil || err.Error() != want { t.Errorf("Tenant() = %v; want = %q", err, want) } } func TestCreateTenant(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client options := (&TenantToCreate{}). DisplayName(testTenant.DisplayName). AllowPasswordSignUp(testTenant.AllowPasswordSignUp). 
EnableEmailLinkSignIn(testTenant.EnableEmailLinkSignIn). EnableAnonymousUsers(testTenant.EnableAnonymousUsers). MultiFactorConfig(*testTenant.MultiFactorConfig) tenant, err := client.TenantManager.CreateTenant(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("CreateTenant() = %#v; want = %#v", tenant, testTenant) } wantBody := map[string]interface{}{ "displayName": testTenant.DisplayName, "allowPasswordSignup": testTenant.AllowPasswordSignUp, "enableEmailLinkSignin": testTenant.EnableEmailLinkSignIn, "enableAnonymousUser": testTenant.EnableAnonymousUsers, "mfaConfig": map[string]interface{}{ "providerConfigs": []interface{}{ map[string]interface{}{ "state": "ENABLED", "totpProviderConfig": map[string]interface{}{ "adjacentIntervals": float64(5), }, }, }, }, } if err := checkCreateTenantRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateTenantMinimal(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client tenant, err := client.TenantManager.CreateTenant(context.Background(), &TenantToCreate{}) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("CreateTenant() = %#v; want = %#v", tenant, testTenant) } wantBody := map[string]interface{}{} if err := checkCreateTenantRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateTenantZeroValues(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client options := (&TenantToCreate{}). DisplayName(""). AllowPasswordSignUp(false). EnableEmailLinkSignIn(false). 
EnableAnonymousUsers(false) tenant, err := client.TenantManager.CreateTenant(context.Background(), options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("CreateTenant() = %#v; want = %#v", tenant, testTenant) } wantBody := map[string]interface{}{ "displayName": "", "allowPasswordSignup": false, "enableEmailLinkSignin": false, "enableAnonymousUser": false, } if err := checkCreateTenantRequest(s, wantBody); err != nil { t.Fatal(err) } } func TestCreateTenantError(t *testing.T) { s := echoServer([]byte("{}"), t) s.Status = http.StatusInternalServerError defer s.Close() client := s.Client client.TenantManager.httpClient.RetryConfig = nil tenant, err := client.TenantManager.CreateTenant(context.Background(), &TenantToCreate{}) if tenant != nil || !errorutils.IsInternal(err) { t.Errorf("CreateTenant() = (%v, %v); want = (nil, %q)", tenant, err, "internal-error") } } func TestCreateTenantNilOptions(t *testing.T) { tm := &TenantManager{} want := "tenant must not be nil" if _, err := tm.CreateTenant(context.Background(), nil); err == nil || err.Error() != want { t.Errorf("CreateTenant(nil) = %v, want = %q", err, want) } } func TestUpdateTenant(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client options := (&TenantToUpdate{}). DisplayName(testTenant.DisplayName). AllowPasswordSignUp(testTenant.AllowPasswordSignUp). EnableEmailLinkSignIn(testTenant.EnableEmailLinkSignIn). EnableAnonymousUsers(testTenant.EnableAnonymousUsers). 
MultiFactorConfig(*testTenant.MultiFactorConfig) tenant, err := client.TenantManager.UpdateTenant(context.Background(), "tenantID", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("UpdateTenant() = %#v; want = %#v", tenant, testTenant) } wantBody := map[string]interface{}{ "displayName": testTenant.DisplayName, "allowPasswordSignup": testTenant.AllowPasswordSignUp, "enableEmailLinkSignin": testTenant.EnableEmailLinkSignIn, "enableAnonymousUser": testTenant.EnableAnonymousUsers, "mfaConfig": map[string]interface{}{ "providerConfigs": []interface{}{ map[string]interface{}{ "state": "ENABLED", "totpProviderConfig": map[string]interface{}{ "adjacentIntervals": float64(5), }, }, }, }, } wantMask := []string{"allowPasswordSignup", "displayName", "enableAnonymousUser", "enableEmailLinkSignin", "mfaConfig"} if err := checkUpdateTenantRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateTenantMinimal(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client options := (&TenantToUpdate{}).DisplayName(testTenant.DisplayName) tenant, err := client.TenantManager.UpdateTenant(context.Background(), "tenantID", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("UpdateTenant() = %#v; want = %#v", tenant, testTenant) } wantBody := map[string]interface{}{ "displayName": testTenant.DisplayName, } wantMask := []string{"displayName"} if err := checkUpdateTenantRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateTenantZeroValues(t *testing.T) { s := echoServer([]byte(tenantResponse), t) defer s.Close() client := s.Client options := (&TenantToUpdate{}). DisplayName(""). AllowPasswordSignUp(false). EnableEmailLinkSignIn(false). 
EnableAnonymousUsers(false) tenant, err := client.TenantManager.UpdateTenant(context.Background(), "tenantID", options) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, testTenant) { t.Errorf("UpdateTenant() = %#v; want = %#v", tenant, testTenant) } wantBody := map[string]interface{}{ "displayName": "", "allowPasswordSignup": false, "enableEmailLinkSignin": false, "enableAnonymousUser": false, } wantMask := []string{"allowPasswordSignup", "displayName", "enableAnonymousUser", "enableEmailLinkSignin"} if err := checkUpdateTenantRequest(s, wantBody, wantMask); err != nil { t.Fatal(err) } } func TestUpdateTenantError(t *testing.T) { s := echoServer([]byte("{}"), t) s.Status = http.StatusInternalServerError defer s.Close() client := s.Client client.TenantManager.httpClient.RetryConfig = nil options := (&TenantToUpdate{}).DisplayName("") tenant, err := client.TenantManager.UpdateTenant(context.Background(), "tenantID", options) if tenant != nil || !errorutils.IsInternal(err) { t.Errorf("UpdateTenant() = (%v, %v); want = (nil, %q)", tenant, err, "internal-error") } } func TestUpdateTenantEmptyID(t *testing.T) { tm := &TenantManager{} want := "tenantID must not be empty" options := (&TenantToUpdate{}).DisplayName("") if _, err := tm.UpdateTenant(context.Background(), "", options); err == nil || err.Error() != want { t.Errorf("UpdateTenant(nil) = %v, want = %q", err, want) } } func TestUpdateTenantNilOptions(t *testing.T) { tm := &TenantManager{} want := "tenant must not be nil" if _, err := tm.UpdateTenant(context.Background(), "tenantID", nil); err == nil || err.Error() != want { t.Errorf("UpdateTenant(nil) = %v, want = %q", err, want) } } func TestUpdateTenantEmptyOptions(t *testing.T) { tm := &TenantManager{} want := "no parameters specified in the update request" if _, err := tm.UpdateTenant(context.Background(), "tenantID", &TenantToUpdate{}); err == nil || err.Error() != want { t.Errorf("UpdateTenant({}) = %v, want = %q", err, want) } } func 
TestDeleteTenant(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() client := s.Client if err := client.TenantManager.DeleteTenant(context.Background(), "tenantID"); err != nil { t.Fatalf("DeleteTenant() = %v", err) } req := s.Req[0] if req.Method != http.MethodDelete { t.Errorf("DeleteTenant() Method = %q; want = %q", req.Method, http.MethodDelete) } wantURL := "/projects/mock-project-id/tenants/tenantID" if req.URL.Path != wantURL { t.Errorf("DeleteTenant() URL = %q; want = %q", req.URL.Path, wantURL) } } func TestDeleteTenantEmptyID(t *testing.T) { tm := &TenantManager{} wantErr := "tenantID must not be empty" err := tm.DeleteTenant(context.Background(), "") if err == nil || err.Error() != wantErr { t.Errorf("DeleteTenant('') = %v; want = (nil, %q)", err, wantErr) } } func TestDeleteTenantError(t *testing.T) { s := echoServer([]byte(tenantNotFoundResponse), t) defer s.Close() s.Status = http.StatusNotFound client := s.Client err := client.TenantManager.DeleteTenant(context.Background(), "tenantID") if err == nil || !IsTenantNotFound(err) { t.Errorf("DeleteTenant() = %v; want = TenantNotFound", err) } } func TestTenants(t *testing.T) { template := `{ "tenants": [ %s, %s, %s ], "nextPageToken": "" }` response := fmt.Sprintf(template, tenantResponse, tenantResponse2, tenantResponse) s := echoServer([]byte(response), t) defer s.Close() want := []*Tenant{ testTenant, testTenant2, testTenant, } wantPath := "/projects/mock-project-id/tenants" testIterator := func(iter *TenantIterator, token string, req string) { count := 0 for i := 0; i < len(want); i++ { tenant, err := iter.Next() if err == iterator.Done { break } if err != nil { t.Fatal(err) } if !reflect.DeepEqual(tenant, want[i]) { t.Errorf("Tenants(%q) = %#v; want = %#v", token, tenant, want[i]) } count++ } if count != len(want) { t.Errorf("Tenants(%q) = %d; want = %d", token, count, len(want)) } if _, err := iter.Next(); err != iterator.Done { t.Errorf("Tenants(%q) = %v; want = %v", token, err, 
iterator.Done) } url := s.Req[len(s.Req)-1].URL if url.Path != wantPath { t.Errorf("Tenants(%q) = %q; want = %q", token, url.Path, wantPath) } // Check the query string of the last HTTP request made. gotReq := url.Query().Encode() if gotReq != req { t.Errorf("Tenants(%q) = %q; want = %v", token, gotReq, req) } } client := s.Client testIterator( client.TenantManager.Tenants(context.Background(), ""), "", "pageSize=100") testIterator( client.TenantManager.Tenants(context.Background(), "pageToken"), "pageToken", "pageSize=100&pageToken=pageToken") } func TestTenantsError(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() s.Status = http.StatusInternalServerError client := s.Client client.TenantManager.httpClient.RetryConfig = nil it := client.TenantManager.Tenants(context.Background(), "") config, err := it.Next() if config != nil || !errorutils.IsInternal(err) { t.Errorf("Tenants() = (%v, %v); want = (nil, %q)", config, err, "internal-error") } } func checkCreateTenantRequest(s *mockAuthServer, wantBody interface{}) error { req := s.Req[0] if req.Method != http.MethodPost { return fmt.Errorf("CreateTenant() Method = %q; want = %q", req.Method, http.MethodPost) } wantURL := "/projects/mock-project-id/tenants" if req.URL.Path != wantURL { return fmt.Errorf("CreateTenant() URL = %q; want = %q", req.URL.Path, wantURL) } var body map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("CreateTenant() Body = %#v; want = %#v", body, wantBody) } return nil } func checkUpdateTenantRequest(s *mockAuthServer, wantBody interface{}, wantMask []string) error { req := s.Req[0] if req.Method != http.MethodPatch { return fmt.Errorf("UpdateTenant() Method = %q; want = %q", req.Method, http.MethodPatch) } wantURL := "/projects/mock-project-id/tenants/tenantID" if req.URL.Path != wantURL { return fmt.Errorf("UpdateTenant() URL = %q; want = %q", req.URL.Path, wantURL) } queryParam 
:= req.URL.Query().Get("updateMask") mask := strings.Split(queryParam, ",") sort.Strings(mask) if !reflect.DeepEqual(mask, wantMask) { return fmt.Errorf("UpdateTenant() Query = %#v; want = %#v", mask, wantMask) } var body map[string]interface{} if err := json.Unmarshal(s.Rbody, &body); err != nil { return err } if !reflect.DeepEqual(body, wantBody) { return fmt.Errorf("UpdateTenant() Body = %#v; want = %#v", body, wantBody) } return nil } golang-google-firebase-go-4.18.0/auth/token_generator.go000066400000000000000000000173621505612111400231710ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "errors" "fmt" "net/http" "strings" "sync" "firebase.google.com/go/v4/internal" ) const ( algorithmNone = "none" algorithmRS256 = "RS256" emulatorEmail = "firebase-auth-emulator@example.com" ) type jwtHeader struct { Algorithm string `json:"alg"` Type string `json:"typ"` KeyID string `json:"kid,omitempty"` } type customToken struct { Iss string `json:"iss"` Aud string `json:"aud"` Exp int64 `json:"exp"` Iat int64 `json:"iat"` Sub string `json:"sub,omitempty"` UID string `json:"uid,omitempty"` TenantID string `json:"tenant_id,omitempty"` Claims map[string]interface{} `json:"claims,omitempty"` } type jwtInfo struct { header jwtHeader payload interface{} } // Token encodes the data in the jwtInfo into a signed JSON web token. func (info *jwtInfo) Token(ctx context.Context, signer cryptoSigner) (string, error) { encode := func(i interface{}) (string, error) { b, err := json.Marshal(i) if err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } header, err := encode(info.header) if err != nil { return "", err } payload, err := encode(info.payload) if err != nil { return "", err } tokenData := fmt.Sprintf("%s.%s", header, payload) sig, err := signer.Sign(ctx, []byte(tokenData)) if err != nil { return "", err } return fmt.Sprintf("%s.%s", tokenData, base64.RawURLEncoding.EncodeToString(sig)), nil } type serviceAccount struct { PrivateKey string `json:"private_key"` ClientEmail string `json:"client_email"` } // cryptoSigner is used to cryptographically sign data, and query the identity of the signer. type cryptoSigner interface { Algorithm() string Sign(context.Context, []byte) ([]byte, error) Email(context.Context) (string, error) } // serviceAccountSigner is a cryptoSigner that signs data using service account credentials. 
type serviceAccountSigner struct { privateKey *rsa.PrivateKey clientEmail string } var errNotAServiceAcct = errors.New("credentials json is not a service account") func signerFromCreds(creds []byte) (cryptoSigner, error) { var sa serviceAccount if err := json.Unmarshal(creds, &sa); err != nil { return nil, err } if sa.PrivateKey != "" && sa.ClientEmail != "" { return newServiceAccountSigner(sa) } return nil, errNotAServiceAcct } func newServiceAccountSigner(sa serviceAccount) (*serviceAccountSigner, error) { block, _ := pem.Decode([]byte(sa.PrivateKey)) if block == nil { return nil, fmt.Errorf("no private key data found in: %q", sa.PrivateKey) } parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err) } } rsaKey, ok := parsedKey.(*rsa.PrivateKey) if !ok { return nil, errors.New("private key is not an RSA key") } return &serviceAccountSigner{ privateKey: rsaKey, clientEmail: sa.ClientEmail, }, nil } func (s serviceAccountSigner) Algorithm() string { return algorithmRS256 } func (s serviceAccountSigner) Sign(ctx context.Context, b []byte) ([]byte, error) { hash := sha256.New() hash.Write(b) return rsa.SignPKCS1v15(rand.Reader, s.privateKey, crypto.SHA256, hash.Sum(nil)) } func (s serviceAccountSigner) Email(ctx context.Context) (string, error) { return s.clientEmail, nil } // iamSigner is a cryptoSigner that signs data by sending them to the IAMCredentials service. See // https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/signBlob // for details regarding the REST API. // // IAMCredentials requires the identity of a service account. This can be specified explicitly // at initialization. 
If not specified iamSigner attempts to discover a service account identity by // calling the local metadata service (works in environments like Google Compute Engine). type iamSigner struct { mutex *sync.Mutex httpClient *internal.HTTPClient serviceAcct string metadataHost string iamHost string } func newIAMSigner(ctx context.Context, config *internal.AuthConfig) (*iamSigner, error) { hc, _, err := internal.NewHTTPClient(ctx, config.Opts...) if err != nil { return nil, err } hc.Opts = []internal.HTTPOption{ internal.WithHeader("x-goog-api-client", internal.GetMetricsHeader(config.Version)), } return &iamSigner{ mutex: &sync.Mutex{}, httpClient: hc, serviceAcct: config.ServiceAccountID, metadataHost: "http://metadata.google.internal", iamHost: "https://iamcredentials.googleapis.com", }, nil } func (s iamSigner) Algorithm() string { return algorithmRS256 } func (s iamSigner) Sign(ctx context.Context, b []byte) ([]byte, error) { account, err := s.Email(ctx) if err != nil { return nil, err } url := fmt.Sprintf("%s/v1/projects/-/serviceAccounts/%s:signBlob", s.iamHost, account) body := map[string]interface{}{ "payload": base64.StdEncoding.EncodeToString(b), } req := &internal.Request{ Method: http.MethodPost, URL: url, Body: internal.NewJSONEntity(body), } var signResponse struct { Signature string `json:"signedBlob"` } if _, err := s.httpClient.DoAndUnmarshal(ctx, req, &signResponse); err != nil { return nil, err } return base64.StdEncoding.DecodeString(signResponse.Signature) } func (s iamSigner) Email(ctx context.Context) (string, error) { if s.serviceAcct != "" { return s.serviceAcct, nil } s.mutex.Lock() defer s.mutex.Unlock() result, err := s.callMetadataService(ctx) if err != nil { msg := "failed to determine service account: %v; initialize the SDK with service " + "account credentials or specify a service account with iam.serviceAccounts.signBlob " + "permission; refer to https://firebase.google.com/docs/auth/admin/create-custom-tokens " + "for more details on 
creating custom tokens" return "", fmt.Errorf(msg, err) } s.serviceAcct = result return result, nil } func (s iamSigner) callMetadataService(ctx context.Context) (string, error) { // Use the built-in default client without request authorization or retries for this call. noAuthClient := &internal.HTTPClient{ Client: http.DefaultClient, } url := fmt.Sprintf("%s/computeMetadata/v1/instance/service-accounts/default/email", s.metadataHost) req := &internal.Request{ Method: http.MethodGet, URL: url, Opts: []internal.HTTPOption{ internal.WithHeader("Metadata-Flavor", "Google"), }, } resp, err := noAuthClient.Do(ctx, req) if err != nil { return "", err } result := strings.TrimSpace(string(resp.Body)) if result == "" { return "", errors.New("unexpected response from metadata service") } return result, nil } type emulatedSigner struct{} func (s emulatedSigner) Algorithm() string { return algorithmNone } func (s emulatedSigner) Email(context.Context) (string, error) { return emulatorEmail, nil } func (s emulatedSigner) Sign(context.Context, []byte) ([]byte, error) { return []byte(""), nil } golang-google-firebase-go-4.18.0/auth/token_generator_test.go000066400000000000000000000236021505612111400242220ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "encoding/base64" "encoding/json" "errors" "fmt" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/internal" ) func TestEncodeToken(t *testing.T) { info := &jwtInfo{ header: jwtHeader{Algorithm: "RS256", Type: "JWT"}, payload: mockIDTokenPayload{"key": "value"}, } s, err := info.Token(context.Background(), &mockSigner{}) if err != nil { t.Fatal(err) } parts := strings.Split(s, ".") if len(parts) != 3 { t.Errorf("encodeToken() = %d; want: %d", len(parts), 3) } var header jwtHeader if err := decode(parts[0], &header); err != nil { t.Fatal(err) } else if info.header != header { t.Errorf("decode(header) = %v; want = %v", header, info.header) } payload := make(mockIDTokenPayload) if err := decode(parts[1], &payload); err != nil { t.Fatal(err) } else if len(payload) != 1 || payload["key"] != "value" { t.Errorf("decode(payload) = %v; want = %v", payload, info.payload) } if sig, err := base64.RawURLEncoding.DecodeString(parts[2]); err != nil { t.Fatal(err) } else if string(sig) != "signedBlob" { t.Errorf("decode(signature) = %q; want = %q", string(sig), "signedBlob") } } func TestEncodeSignError(t *testing.T) { signer := &mockSigner{ err: errors.New("sign error"), } info := &jwtInfo{ header: jwtHeader{Algorithm: "RS256", Type: "JWT"}, payload: mockIDTokenPayload{"key": "value"}, } if s, err := info.Token(context.Background(), signer); s != "" || err != signer.err { t.Errorf("encodeToken() = (%v, %v); want = ('', %v)", s, err, signer.err) } } func TestEncodeInvalidPayload(t *testing.T) { info := &jwtInfo{ header: jwtHeader{Algorithm: "RS256", Type: "JWT"}, payload: mockIDTokenPayload{"key": func() {}}, } s, err := info.Token(context.Background(), &mockSigner{}) if s != "" || err == nil { t.Errorf("encodeToken() = (%v, %v); want = ('', error)", s, err) } } func TestServiceAccountSigner(t *testing.T) { b, err := ioutil.ReadFile("../testdata/service_account.json") if 
err != nil { t.Fatal(err) } var sa serviceAccount if err := json.Unmarshal(b, &sa); err != nil { t.Fatal(err) } signer, err := newServiceAccountSigner(sa) if err != nil { t.Fatal(err) } algorithm := signer.Algorithm() if algorithm != algorithmRS256 { t.Errorf("Algorithm() = %q; want = %q", algorithm, algorithmRS256) } email, err := signer.Email(context.Background()) if email != sa.ClientEmail || err != nil { t.Errorf("Email() = (%q, %v); want = (%q, nil)", email, err, sa.ClientEmail) } sign, err := signer.Sign(context.Background(), []byte("test")) if sign == nil || err != nil { t.Errorf("Sign() = (%v, %v); want = (bytes, nil)", email, err) } } func TestIAMSigner(t *testing.T) { ctx := context.Background() conf := &internal.AuthConfig{ Opts: optsWithTokenSource, ServiceAccountID: "test-service-account", Version: testVersion, } signer, err := newIAMSigner(ctx, conf) if err != nil { t.Fatal(err) } algorithm := signer.Algorithm() if algorithm != algorithmRS256 { t.Errorf("Algorithm() = %q; want = %q", algorithm, algorithmRS256) } email, err := signer.Email(ctx) if email != conf.ServiceAccountID || err != nil { t.Errorf("Email() = (%q, %v); want = (%q, nil)", email, err, conf.ServiceAccountID) } wantSignature := "test-signature" server := iamServer(t, email, wantSignature) defer server.Close() signer.iamHost = server.URL signature, err := signer.Sign(ctx, []byte("input")) if err != nil { t.Fatal(err) } if string(signature) != wantSignature { t.Errorf("Sign() = %q; want = %q", string(signature), wantSignature) } } func TestIAMSignerHTTPError(t *testing.T) { conf := &internal.AuthConfig{ Opts: optsWithTokenSource, ServiceAccountID: "test-service-account", Version: testVersion, } signer, err := newIAMSigner(context.Background(), conf) if err != nil { t.Fatal(err) } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() w.WriteHeader(http.StatusForbidden) w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"error": 
{"status": "PERMISSION_DENIED", "message": "test reason"}}`)) }) server := httptest.NewServer(handler) defer server.Close() signer.iamHost = server.URL want := "test reason" _, err = signer.Sign(context.Background(), []byte("input")) if err == nil || !errorutils.IsPermissionDenied(err) || err.Error() != want { t.Errorf("Sign() = %v; want = %q", err, want) } } func TestIAMSignerUnknownHTTPError(t *testing.T) { conf := &internal.AuthConfig{ Opts: optsWithTokenSource, ServiceAccountID: "test-service-account", Version: testVersion, } signer, err := newIAMSigner(context.Background(), conf) if err != nil { t.Fatal(err) } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() w.WriteHeader(http.StatusForbidden) w.Header().Set("Content-Type", "application/json") w.Write([]byte(`not json`)) }) server := httptest.NewServer(handler) defer server.Close() signer.iamHost = server.URL want := "unexpected http response with status: 403\nnot json" _, err = signer.Sign(context.Background(), []byte("input")) if err == nil || !errorutils.IsPermissionDenied(err) || err.Error() != want { t.Errorf("Sign() = %v; want = %q", err, want) } } func TestIAMSignerWithMetadataService(t *testing.T) { ctx := context.Background() conf := &internal.AuthConfig{ Opts: optsWithTokenSource, Version: testVersion, } signer, err := newIAMSigner(ctx, conf) if err != nil { t.Fatal(err) } // start mock metadata service and test Email() serviceAcct := "discovered-service-account" handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() flavor := r.Header.Get("Metadata-Flavor") if flavor != "Google" { t.Errorf("Header(Metadata-Flavor) = %q; want = %q", flavor, "Google") } w.Header().Set("Content-Type", "application/text") w.Write([]byte(serviceAcct)) }) metadata := httptest.NewServer(handler) defer metadata.Close() signer.metadataHost = metadata.URL email, err := signer.Email(ctx) if email != serviceAcct || err != nil { 
t.Errorf("Email() = (%q, %v); want = (%q, nil)", email, err, serviceAcct) } // start mock IAM service and test Sign() wantSignature := "test-signature" server := iamServer(t, email, wantSignature) defer server.Close() signer.iamHost = server.URL signature, err := signer.Sign(ctx, []byte("input")) if err != nil { t.Fatal(err) } if string(signature) != wantSignature { t.Errorf("Sign() = %q; want = %q", string(signature), wantSignature) } } func TestIAMSignerNoMetadataService(t *testing.T) { ctx := context.Background() conf := &internal.AuthConfig{ Opts: optsWithTokenSource, Version: testVersion, } signer, err := newIAMSigner(ctx, conf) if err != nil { t.Fatal(err) } signer.metadataHost = "http://non-existing.metadata.service" want := "failed to determine service account: " _, err = signer.Email(ctx) if err == nil || !strings.HasPrefix(err.Error(), want) { t.Errorf("Email() = %v; want = %q", err, want) } _, err = signer.Sign(ctx, []byte("input")) if err == nil || !strings.HasPrefix(err.Error(), want) { t.Errorf("Sign() = %v; want = %q", err, want) } } func TestEmulatedSigner(t *testing.T) { signer := emulatedSigner{} algorithm := signer.Algorithm() if algorithm != algorithmNone { t.Errorf("Algorithm() = %q; want = %q", algorithm, algorithmNone) } email, err := signer.Email(context.Background()) if err != nil { t.Fatal(err) } if email != emulatorEmail { t.Errorf("Email() = %q; want = %q", email, emulatorEmail) } wantSignature := "" sign, err := signer.Sign(context.Background(), []byte("test")) if err != nil { t.Fatal(err) } if string(sign) != wantSignature { t.Errorf("Sign() = %q; want = %q", string(sign), wantSignature) } } type mockSigner struct { err error } func (s *mockSigner) Algorithm() string { return "" } func (s *mockSigner) Email(ctx context.Context) (string, error) { return "", nil } func (s *mockSigner) Sign(ctx context.Context, b []byte) ([]byte, error) { if s.err != nil { return nil, s.err } return []byte("signedBlob"), nil } func iamServer(t *testing.T, 
serviceAcct, signature string) *httptest.Server { resp := map[string]interface{}{ "signedBlob": base64.StdEncoding.EncodeToString([]byte(signature)), } wantPath := fmt.Sprintf("/v1/projects/-/serviceAccounts/%s:signBlob", serviceAcct) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() reqBody, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatal(err) } var m map[string]interface{} if err := json.Unmarshal(reqBody, &m); err != nil { t.Fatal(err) } if m["payload"] == "" { t.Fatal("payload = empty; want = non-empty") } if r.URL.Path != wantPath { t.Errorf("Path = %q; want = %q", r.URL.Path, wantPath) } xGoogAPIClientHeader := internal.GetMetricsHeader(testVersion) if h := r.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader { t.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader) } w.Header().Set("Content-Type", "application/json") b, err := json.Marshal(resp) if err != nil { t.Fatal(err) } w.Write(b) }) return httptest.NewServer(handler) } golang-google-firebase-go-4.18.0/auth/token_verifier.go000066400000000000000000000353041505612111400230120ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "bytes" "context" "crypto" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "errors" "fmt" "io/ioutil" "net/http" "strconv" "strings" "sync" "time" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" "google.golang.org/api/transport" ) const ( idTokenCertURL = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com" idTokenIssuerPrefix = "https://securetoken.google.com/" sessionCookieCertURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/publicKeys" sessionCookieIssuerPrefix = "https://session.firebase.google.com/" clockSkewSeconds = 300 certificateFetchFailed = "CERTIFICATE_FETCH_FAILED" idTokenExpired = "ID_TOKEN_EXPIRED" idTokenInvalid = "ID_TOKEN_INVALID" sessionCookieExpired = "SESSION_COOKIE_EXPIRED" sessionCookieInvalid = "SESSION_COOKIE_INVALID" ) // IsCertificateFetchFailed checks if the given error was caused by a failure to fetch public key // certificates required to verify a JWT. func IsCertificateFetchFailed(err error) bool { return hasAuthErrorCode(err, certificateFetchFailed) } // IsIDTokenExpired checks if the given error was due to an expired ID token. // // When IsIDTokenExpired returns true, IsIDTokenInvalid is guaranteed to return true. func IsIDTokenExpired(err error) bool { return hasAuthErrorCode(err, idTokenExpired) } // IsIDTokenInvalid checks if the given error was due to an invalid ID token. // // An ID token is considered invalid when it is malformed (i.e. contains incorrect data), expired // or revoked. It also returns true for any error matched by IsUserDisabled. func IsIDTokenInvalid(err error) bool { return hasAuthErrorCode(err, idTokenInvalid) || IsIDTokenExpired(err) || IsIDTokenRevoked(err) || IsUserDisabled(err) } // IsSessionCookieExpired checks if the given error was due to an expired session cookie. // // When IsSessionCookieExpired returns true, IsSessionCookieInvalid is guaranteed to return true. 
func IsSessionCookieExpired(err error) bool { return hasAuthErrorCode(err, sessionCookieExpired) } // IsSessionCookieInvalid checks if the given error was due to an invalid session cookie. // // A session cookie is considered invalid when it is malformed (i.e. contains incorrect data), // expired or revoked. func IsSessionCookieInvalid(err error) bool { return hasAuthErrorCode(err, sessionCookieInvalid) || IsSessionCookieExpired(err) || IsSessionCookieRevoked(err) || IsUserDisabled(err) } // tokenVerifier verifies different types of Firebase token strings, including ID tokens and // session cookies. type tokenVerifier struct { shortName string articledShortName string docURL string projectID string issuerPrefix string invalidTokenCode string expiredTokenCode string keySource keySource clock internal.Clock } func newIDTokenVerifier(ctx context.Context, projectID string) (*tokenVerifier, error) { noAuthHTTPClient, _, err := transport.NewHTTPClient(ctx, option.WithoutAuthentication()) if err != nil { return nil, err } return &tokenVerifier{ shortName: "ID token", articledShortName: "an ID token", docURL: "https://firebase.google.com/docs/auth/admin/verify-id-tokens", projectID: projectID, issuerPrefix: idTokenIssuerPrefix, invalidTokenCode: idTokenInvalid, expiredTokenCode: idTokenExpired, keySource: newHTTPKeySource(idTokenCertURL, noAuthHTTPClient), clock: internal.SystemClock, }, nil } func newSessionCookieVerifier(ctx context.Context, projectID string) (*tokenVerifier, error) { noAuthHTTPClient, _, err := transport.NewHTTPClient(ctx, option.WithoutAuthentication()) if err != nil { return nil, err } return &tokenVerifier{ shortName: "session cookie", articledShortName: "a session cookie", docURL: "https://firebase.google.com/docs/auth/admin/manage-cookies", projectID: projectID, issuerPrefix: sessionCookieIssuerPrefix, invalidTokenCode: sessionCookieInvalid, expiredTokenCode: sessionCookieExpired, keySource: newHTTPKeySource(sessionCookieCertURL, noAuthHTTPClient), 
clock: internal.SystemClock, }, nil } // VerifyToken Verifies that the given token string is a valid Firebase JWT. // // VerifyToken considers a token string to be valid if all the following conditions are met: // - The token string is a valid RS256 JWT. // - The JWT contains a valid key ID (kid) claim. // - The JWT contains valid issuer (iss) and audience (aud) claims that match the issuerPrefix // and projectID of the tokenVerifier. // - The JWT contains a valid subject (sub) claim. // - The JWT is not expired, and it has been issued some time in the past. // - The JWT is signed by a Firebase Auth backend server as determined by the keySource. // // If any of the above conditions are not met, an error is returned. Otherwise a pointer to a // decoded Token is returned. func (tv *tokenVerifier) VerifyToken(ctx context.Context, token string, isEmulator bool) (*Token, error) { if tv.projectID == "" { // Configuration error. return nil, errors.New("project id not available") } // Validate the token content first. This is fast and cheap. payload, err := tv.verifyContent(token, isEmulator) if err != nil { return nil, err } if err := tv.verifyTimestamps(payload); err != nil { return nil, err } // In emulator mode, skip signature verification if isEmulator { return payload, nil } // Verifying the signature requires synchronized access to a key cache and // potentially issues an http request. Therefore we do it last. 
if err := tv.verifySignature(ctx, token); err != nil { return nil, err } return payload, nil } func (tv *tokenVerifier) verifyContent(token string, isEmulator bool) (*Token, error) { if token == "" { return nil, &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: fmt.Sprintf("%s must be a non-empty string", tv.shortName), Ext: map[string]interface{}{authErrorCode: tv.invalidTokenCode}, } } payload, err := tv.verifyHeaderAndBody(token, isEmulator) if err != nil { return nil, &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: fmt.Sprintf( "%s; see %s for details on how to retrieve a valid %s", err.Error(), tv.docURL, tv.shortName), Ext: map[string]interface{}{authErrorCode: tv.invalidTokenCode}, } } return payload, nil } func (tv *tokenVerifier) verifyTimestamps(payload *Token) error { if (payload.IssuedAt - clockSkewSeconds) > tv.clock.Now().Unix() { return &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: fmt.Sprintf("%s issued at future timestamp: %d", tv.shortName, payload.IssuedAt), Ext: map[string]interface{}{authErrorCode: tv.invalidTokenCode}, } } if (payload.Expires + clockSkewSeconds) < tv.clock.Now().Unix() { return &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: fmt.Sprintf("%s has expired at: %d", tv.shortName, payload.Expires), Ext: map[string]interface{}{authErrorCode: tv.expiredTokenCode}, } } return nil } func (tv *tokenVerifier) verifySignature(ctx context.Context, token string) error { keys, err := tv.keySource.Keys(ctx) if err != nil { return &internal.FirebaseError{ ErrorCode: internal.Unknown, String: err.Error(), Ext: map[string]interface{}{authErrorCode: certificateFetchFailed}, } } if !tv.verifySignatureWithKeys(ctx, token, keys) { return &internal.FirebaseError{ ErrorCode: internal.InvalidArgument, String: "failed to verify token signature", Ext: map[string]interface{}{authErrorCode: tv.invalidTokenCode}, } } return nil } func (tv *tokenVerifier) 
verifyHeaderAndBody(token string, isEmulator bool) (*Token, error) { var ( header jwtHeader payload Token ) segments := strings.Split(token, ".") if len(segments) != 3 { return nil, errors.New("incorrect number of segments") } if err := decode(segments[0], &header); err != nil { return nil, err } if err := decode(segments[1], &payload); err != nil { return nil, err } issuer := tv.issuerPrefix + tv.projectID if !isEmulator && header.KeyID == "" { if payload.Audience == firebaseAudience { return nil, fmt.Errorf("expected %s but got a custom token", tv.articledShortName) } return nil, fmt.Errorf("%s has no 'kid' header", tv.shortName) } if !isEmulator && header.Algorithm != "RS256" { return nil, fmt.Errorf("%s has invalid algorithm; expected 'RS256' but got %q", tv.shortName, header.Algorithm) } if payload.Audience != tv.projectID { return nil, fmt.Errorf("%s has invalid 'aud' (audience) claim; expected %q but got %q; %s", tv.shortName, tv.projectID, payload.Audience, tv.getProjectIDMatchMessage()) } if payload.Issuer != issuer { return nil, fmt.Errorf("%s has invalid 'iss' (issuer) claim; expected %q but got %q; %s", tv.shortName, issuer, payload.Issuer, tv.getProjectIDMatchMessage()) } if payload.Subject == "" { return nil, fmt.Errorf("%s has empty 'sub' (subject) claim", tv.shortName) } if len(payload.Subject) > 128 { return nil, fmt.Errorf("%s has a 'sub' (subject) claim longer than 128 characters", tv.shortName) } payload.UID = payload.Subject var customClaims map[string]interface{} if err := decode(segments[1], &customClaims); err != nil { return nil, err } for _, standardClaim := range []string{"iss", "aud", "exp", "iat", "sub", "uid"} { delete(customClaims, standardClaim) } payload.Claims = customClaims return &payload, nil } func (tv *tokenVerifier) verifySignatureWithKeys(ctx context.Context, token string, keys []*publicKey) bool { segments := strings.Split(token, ".") var h jwtHeader decode(segments[0], &h) verified := false for _, k := range keys { if 
// verifyJWTSignature checks the RS256 signature of a JWT that has already been
// split into its three dot-separated segments (header, payload, signature).
//
// parts must contain exactly three segments and k must carry a non-nil RSA
// public key; otherwise an error is returned instead of panicking on a bad
// index or nil dereference. The signature segment must be unpadded base64url.
// A nil return means the signature over "header.payload" was produced by the
// private key matching k.
func verifyJWTSignature(parts []string, k *publicKey) error {
	if len(parts) != 3 {
		return errors.New("incorrect number of segments")
	}
	if k == nil || k.Key == nil {
		return errors.New("public key not available")
	}

	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return err
	}

	// The signed content is the raw (still encoded) header and payload joined
	// by a dot, exactly as they appear in the token string.
	digest := sha256.Sum256([]byte(parts[0] + "." + parts[1]))
	return rsa.VerifyPKCS1v15(k.Key, crypto.SHA256, digest[:], signature)
}

// publicKey represents a parsed RSA public key along with its unique key ID.
type publicKey struct {
	Kid string
	Key *rsa.PublicKey
}
func (k *httpKeySource) Keys(ctx context.Context) ([]*publicKey, error) { k.Mutex.Lock() defer k.Mutex.Unlock() if len(k.CachedKeys) == 0 || k.hasExpired() { err := k.refreshKeys(ctx) if err != nil && len(k.CachedKeys) == 0 { return nil, err } } return k.CachedKeys, nil } // hasExpired indicates whether the cache has expired. func (k *httpKeySource) hasExpired() bool { return k.Clock.Now().After(k.ExpiryTime) } func (k *httpKeySource) refreshKeys(ctx context.Context) error { k.CachedKeys = nil req, err := http.NewRequest(http.MethodGet, k.KeyURI, nil) if err != nil { return err } resp, err := k.HTTPClient.Do(req.WithContext(ctx)) if err != nil { return err } defer resp.Body.Close() contents, err := ioutil.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode != http.StatusOK { return fmt.Errorf("invalid response (%d) while retrieving public keys: %s", resp.StatusCode, string(contents)) } newKeys, err := parsePublicKeys(contents) if err != nil { return err } maxAge := findMaxAge(resp) k.CachedKeys = append([]*publicKey(nil), newKeys...) 
k.ExpiryTime = k.Clock.Now().Add(*maxAge) return nil } func parsePublicKeys(keys []byte) ([]*publicKey, error) { m := make(map[string]string) err := json.Unmarshal(keys, &m) if err != nil { return nil, err } var result []*publicKey for kid, key := range m { pubKey, err := parsePublicKey(kid, []byte(key)) if err != nil { return nil, err } result = append(result, pubKey) } return result, nil } func parsePublicKey(kid string, key []byte) (*publicKey, error) { block, _ := pem.Decode(key) if block == nil { return nil, errors.New("failed to decode the certificate as PEM") } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, err } pk, ok := cert.PublicKey.(*rsa.PublicKey) if !ok { return nil, errors.New("certificate is not an RSA key") } return &publicKey{kid, pk}, nil } func findMaxAge(resp *http.Response) *time.Duration { cc := resp.Header.Get("cache-control") for _, value := range strings.Split(cc, ",") { value = strings.TrimSpace(value) if strings.HasPrefix(value, "max-age=") { sep := strings.Index(value, "=") seconds, err := strconv.ParseInt(value[sep+1:], 10, 64) if err != nil { seconds = 0 } duration := time.Duration(seconds) * time.Second return &duration } } defaultDuration := time.Duration(0) * time.Second return &defaultDuration } golang-google-firebase-go-4.18.0/auth/token_verifier_test.go000066400000000000000000000153751505612111400240570ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "errors" "fmt" "io" "io/ioutil" "net/http" "testing" "time" "firebase.google.com/go/v4/internal" ) func TestNewIDTokenVerifier(t *testing.T) { tv, err := newIDTokenVerifier(context.Background(), testProjectID) if err != nil { t.Fatal(err) } if tv.shortName != "ID token" { t.Errorf("tokenVerifier.shortName = %q; want = %q", tv.shortName, "ID token") } if tv.projectID != testProjectID { t.Errorf("tokenVerifier.projectID = %q; want = %q", tv.projectID, testProjectID) } if tv.issuerPrefix != idTokenIssuerPrefix { t.Errorf("tokenVerifier.issuerPrefix = %q; want = %q", tv.issuerPrefix, idTokenIssuerPrefix) } ks, ok := tv.keySource.(*httpKeySource) if !ok { t.Fatalf("tokenVerifier.keySource = %#v; want = httpKeySource", tv.keySource) } if ks.KeyURI != idTokenCertURL { t.Errorf("tokenVerifier.certURL = %q; want = %q", ks.KeyURI, idTokenCertURL) } } func TestHTTPKeySource(t *testing.T) { data, err := ioutil.ReadFile("../testdata/public_certs.json") if err != nil { t.Fatal(err) } ks := newHTTPKeySource("http://mock.url", http.DefaultClient) if ks.HTTPClient == nil { t.Errorf("HTTPClient = nil; want = non-nil") } hc, rc := newTestHTTPClient(data) ks.HTTPClient = hc if err := verifyHTTPKeySource(ks, rc); err != nil { t.Fatal(err) } } func TestHTTPKeySourceWithClient(t *testing.T) { data, err := ioutil.ReadFile("../testdata/public_certs.json") if err != nil { t.Fatal(err) } hc, rc := newTestHTTPClient(data) ks := newHTTPKeySource("http://mock.url", hc) if ks.HTTPClient != hc { t.Errorf("HTTPClient = %v; want = %v", ks.HTTPClient, hc) } if err := verifyHTTPKeySource(ks, rc); err != nil { t.Fatal(err) } } func TestHTTPKeySourceEmptyResponse(t *testing.T) { hc, _ := newTestHTTPClient([]byte("")) ks := newHTTPKeySource("http://mock.url", hc) if keys, err := ks.Keys(context.Background()); keys != nil || err == nil { t.Errorf("Keys() = (%v, %v); 
want = (nil, error)", keys, err) } } func TestHTTPKeySourceIncorrectResponse(t *testing.T) { hc, _ := newTestHTTPClient([]byte("{\"foo\": \"bar\"}")) ks := newHTTPKeySource("http://mock.url", hc) if keys, err := ks.Keys(context.Background()); keys != nil || err == nil { t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err) } } func TestHTTPKeySourceHTTPError(t *testing.T) { rc := &mockReadCloser{ data: string(""), closeCount: 0, } client := &http.Client{ Transport: &mockHTTPResponse{ Response: http.Response{ Status: "503 Service Unavailable", StatusCode: http.StatusServiceUnavailable, Body: rc, }, Err: nil, }, } ks := newHTTPKeySource("http://mock.url", client) if keys, err := ks.Keys(context.Background()); keys != nil || err == nil { t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err) } } func TestHTTPKeySourceTransportError(t *testing.T) { hc := &http.Client{ Transport: &mockHTTPResponse{ Err: errors.New("transport error"), }, } ks := newHTTPKeySource("http://mock.url", hc) if keys, err := ks.Keys(context.Background()); keys != nil || err == nil { t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err) } } func TestFindMaxAge(t *testing.T) { cases := []struct { cc string want int64 }{ {"max-age=100", 100}, {"public, max-age=100", 100}, {"public,max-age=100", 100}, {"public, max-age=100, must-revalidate, no-transform", 100}, {"", 0}, {"max-age 100", 0}, {"max-age: 100", 0}, {"max-age2=100", 0}, {"max-age=foo", 0}, {"private,", 0}, } for _, tc := range cases { resp := &http.Response{ Header: http.Header{"Cache-Control": {tc.cc}}, } age := findMaxAge(resp) if *age != (time.Duration(tc.want) * time.Second) { t.Errorf("findMaxAge(%q) = %v; want = %v", tc.cc, *age, tc.want) } } } func TestParsePublicKeys(t *testing.T) { b, err := ioutil.ReadFile("../testdata/public_certs.json") if err != nil { t.Fatal(err) } keys, err := parsePublicKeys(b) if err != nil { t.Fatal(err) } if len(keys) != 3 { t.Errorf("parsePublicKeys() = %d; want = %d", len(keys), 
3) } } func TestParsePublicKeysError(t *testing.T) { cases := []string{ "", "not-json", } for _, tc := range cases { if keys, err := parsePublicKeys([]byte(tc)); keys != nil || err == nil { t.Errorf("parsePublicKeys(%q) = (%v, %v); want = (nil, err)", tc, keys, err) } } } type mockHTTPResponse struct { Response http.Response Err error } func (m *mockHTTPResponse) RoundTrip(*http.Request) (*http.Response, error) { return &m.Response, m.Err } type mockReadCloser struct { data string index int64 closeCount int } func newTestHTTPClient(data []byte) (*http.Client, *mockReadCloser) { rc := &mockReadCloser{ data: string(data), closeCount: 0, } client := &http.Client{ Transport: &mockHTTPResponse{ Response: http.Response{ Status: "200 OK", StatusCode: http.StatusOK, Header: http.Header{ "Cache-Control": {"public, max-age=100"}, }, Body: rc, }, Err: nil, }, } return client, rc } func (r *mockReadCloser) Read(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } if r.index >= int64(len(r.data)) { return 0, io.EOF } n = copy(p, r.data[r.index:]) r.index += int64(n) return } func (r *mockReadCloser) Close() error { r.closeCount++ r.index = 0 return nil } func verifyHTTPKeySource(ks *httpKeySource, rc *mockReadCloser) error { mc := &internal.MockClock{Timestamp: time.Unix(0, 0)} ks.Clock = mc exp := time.Unix(100, 0) for i := 0; i <= 100; i++ { keys, err := ks.Keys(context.Background()) if err != nil { return err } if len(keys) != 3 { return fmt.Errorf("Keys: %d; want: 3", len(keys)) } else if rc.closeCount != 1 { return fmt.Errorf("HTTP calls: %d; want: 1", rc.closeCount) } else if ks.ExpiryTime != exp { return fmt.Errorf("Expiry: %v; want: %v", ks.ExpiryTime, exp) } mc.Timestamp = mc.Timestamp.Add(time.Second) } mc.Timestamp = time.Unix(101, 0) keys, err := ks.Keys(context.Background()) if err != nil { return err } if len(keys) != 3 { return fmt.Errorf("Keys: %d; want: 3", len(keys)) } else if rc.closeCount != 2 { return fmt.Errorf("HTTP calls: %d; want: 2", 
rc.closeCount) } return nil } golang-google-firebase-go-4.18.0/auth/user_mgt.go000066400000000000000000001346641505612111400216350ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "regexp" "strconv" "strings" "time" "firebase.google.com/go/v4/internal" ) const ( maxLenPayloadCC = 1000 defaultProviderID = "firebase" idToolkitV1Endpoint = "https://identitytoolkit.googleapis.com/v1" // Maximum number of users allowed to batch get at a time. maxGetAccountsBatchSize = 100 // Maximum number of users allowed to batch delete at a time. maxDeleteAccountsBatchSize = 1000 createUserMethod = "createUser" updateUserMethod = "updateUser" phoneMultiFactorID = "phone" totpMultiFactorID = "totp" ) // 'REDACTED', encoded as a base64 string. var b64Redacted = base64.StdEncoding.EncodeToString([]byte("REDACTED")) // UserInfo is a collection of standard profile information for a user. type UserInfo struct { DisplayName string `json:"displayName,omitempty"` Email string `json:"email,omitempty"` PhoneNumber string `json:"phoneNumber,omitempty"` PhotoURL string `json:"photoUrl,omitempty"` // In the ProviderUserInfo[] ProviderID can be a short domain name (e.g. google.com), // or the identity of an OpenID identity provider. // In UserRecord.UserInfo it will return the constant string "firebase". 
ProviderID string `json:"providerId,omitempty"` UID string `json:"rawId,omitempty"` } // multiFactorInfoResponse describes the `mfaInfo` of the user record API response type multiFactorInfoResponse struct { MFAEnrollmentID string `json:"mfaEnrollmentId,omitempty"` DisplayName string `json:"displayName,omitempty"` PhoneInfo string `json:"phoneInfo,omitempty"` TOTPInfo *TOTPInfo `json:"totpInfo,omitempty"` EnrolledAt string `json:"enrolledAt,omitempty"` } // TOTPInfo describes a user enrolled second TOTP factor. type TOTPInfo struct{} // PhoneMultiFactorInfo describes a user enrolled in SMS second factor. type PhoneMultiFactorInfo struct { PhoneNumber string } // TOTPMultiFactorInfo describes a user enrolled in TOTP second factor. type TOTPMultiFactorInfo struct{} type multiFactorEnrollments struct { Enrollments []*multiFactorInfoResponse `json:"enrollments"` } // MultiFactorInfo describes a user enrolled second phone factor. type MultiFactorInfo struct { UID string DisplayName string EnrollmentTimestamp int64 FactorID string PhoneNumber string // Deprecated: Use PhoneMultiFactorInfo instead Phone *PhoneMultiFactorInfo TOTP *TOTPMultiFactorInfo } // MultiFactorSettings describes the multi-factor related user settings. type MultiFactorSettings struct { EnrolledFactors []*MultiFactorInfo } // UserMetadata contains additional metadata associated with a user account. // Timestamps are in milliseconds since epoch. type UserMetadata struct { CreationTimestamp int64 LastLogInTimestamp int64 // The time at which the user was last active (ID token refreshed), or 0 if // the user was never active. LastRefreshTimestamp int64 } // UserRecord contains metadata associated with a Firebase user account. type UserRecord struct { *UserInfo CustomClaims map[string]interface{} Disabled bool EmailVerified bool ProviderUserInfo []*UserInfo TokensValidAfterMillis int64 // milliseconds since epoch. 
UserMetadata *UserMetadata TenantID string MultiFactor *MultiFactorSettings } // UserToCreate is the parameter struct for the CreateUser function. type UserToCreate struct { params map[string]interface{} } // Disabled setter. func (u *UserToCreate) Disabled(disabled bool) *UserToCreate { return u.set("disabled", disabled) } // DisplayName setter. func (u *UserToCreate) DisplayName(name string) *UserToCreate { return u.set("displayName", name) } // Email setter. func (u *UserToCreate) Email(email string) *UserToCreate { return u.set("email", email) } // EmailVerified setter. func (u *UserToCreate) EmailVerified(verified bool) *UserToCreate { return u.set("emailVerified", verified) } // Password setter. func (u *UserToCreate) Password(pw string) *UserToCreate { return u.set("password", pw) } // PhoneNumber setter. func (u *UserToCreate) PhoneNumber(phone string) *UserToCreate { return u.set("phoneNumber", phone) } // PhotoURL setter. func (u *UserToCreate) PhotoURL(url string) *UserToCreate { return u.set("photoUrl", url) } // UID setter. func (u *UserToCreate) UID(uid string) *UserToCreate { return u.set("localId", uid) } // MFASettings setter. func (u *UserToCreate) MFASettings(mfaSettings MultiFactorSettings) *UserToCreate { return u.set("mfaSettings", mfaSettings) } func (u *UserToCreate) set(key string, value interface{}) *UserToCreate { if u.params == nil { u.params = make(map[string]interface{}) } u.params[key] = value return u } // Converts a client format second factor object to server format. 
func convertMultiFactorInfoToServerFormat(mfaInfo MultiFactorInfo) (multiFactorInfoResponse, error) { authFactorInfo := multiFactorInfoResponse{DisplayName: mfaInfo.DisplayName} if mfaInfo.EnrollmentTimestamp != 0 { authFactorInfo.EnrolledAt = time.Unix(mfaInfo.EnrollmentTimestamp, 0).Format("2006-01-02T15:04:05Z07:00Z") } if mfaInfo.UID != "" { authFactorInfo.MFAEnrollmentID = mfaInfo.UID } switch mfaInfo.FactorID { case phoneMultiFactorID: authFactorInfo.PhoneInfo = mfaInfo.Phone.PhoneNumber case totpMultiFactorID: authFactorInfo.TOTPInfo = (*TOTPInfo)(mfaInfo.TOTP) default: out, _ := json.Marshal(mfaInfo) return multiFactorInfoResponse{}, fmt.Errorf("unsupported second factor %s provided", string(out)) } return authFactorInfo, nil } func (u *UserToCreate) validatedRequest() (map[string]interface{}, error) { req := make(map[string]interface{}) for k, v := range u.params { if k == "mfaSettings" { mfaInfo, err := validateAndFormatMfaSettings(v.(MultiFactorSettings), createUserMethod) if err != nil { return nil, err } req["mfaInfo"] = mfaInfo } else { req[k] = v } } if uid, ok := req["localId"]; ok { if err := validateUID(uid.(string)); err != nil { return nil, err } } if name, ok := req["displayName"]; ok { if err := validateDisplayName(name.(string)); err != nil { return nil, err } } if email, ok := req["email"]; ok { if err := validateEmail(email.(string)); err != nil { return nil, err } } if phone, ok := req["phoneNumber"]; ok { if err := validatePhone(phone.(string)); err != nil { return nil, err } } if url, ok := req["photoUrl"]; ok { if err := validatePhotoURL(url.(string)); err != nil { return nil, err } } if pw, ok := req["password"]; ok { if err := validatePassword(pw.(string)); err != nil { return nil, err } } return req, nil } // UserToUpdate is the parameter struct for the UpdateUser function. type UserToUpdate struct { params map[string]interface{} } // CustomClaims setter. 
func (u *UserToUpdate) CustomClaims(claims map[string]interface{}) *UserToUpdate { return u.set("customClaims", claims) } // Disabled setter. func (u *UserToUpdate) Disabled(disabled bool) *UserToUpdate { return u.set("disableUser", disabled) } // DisplayName setter. Set to empty string to remove the display name from the user account. func (u *UserToUpdate) DisplayName(name string) *UserToUpdate { return u.set("displayName", name) } // Email setter. func (u *UserToUpdate) Email(email string) *UserToUpdate { return u.set("email", email) } // EmailVerified setter. func (u *UserToUpdate) EmailVerified(verified bool) *UserToUpdate { return u.set("emailVerified", verified) } // Password setter. func (u *UserToUpdate) Password(pw string) *UserToUpdate { return u.set("password", pw) } // PhoneNumber setter. Set to empty string to remove the phone number and the corresponding auth provider // from the user account. func (u *UserToUpdate) PhoneNumber(phone string) *UserToUpdate { return u.set("phoneNumber", phone) } // PhotoURL setter. Set to empty string to remove the photo URL from the user account. func (u *UserToUpdate) PhotoURL(url string) *UserToUpdate { return u.set("photoUrl", url) } // MFASettings setter. func (u *UserToUpdate) MFASettings(mfaSettings MultiFactorSettings) *UserToUpdate { return u.set("mfaSettings", mfaSettings) } // ProviderToLink links this user to the specified provider. // // Linking a provider to an existing user account does not invalidate the // refresh token of that account. In other words, the existing account would // continue to be able to access resources, despite not having used the newly // linked provider to log in. If you wish to force the user to authenticate // with this new provider, you need to (a) revoke their refresh token (see // https://firebase.google.com/docs/auth/admin/manage-sessions#revoke_refresh_tokens), // and (b) ensure no other authentication methods are present on this account. 
func (u *UserToUpdate) ProviderToLink(userProvider *UserProvider) *UserToUpdate { return u.set("linkProviderUserInfo", userProvider) } // ProvidersToDelete unlinks this user from the specified providers. func (u *UserToUpdate) ProvidersToDelete(providerIds []string) *UserToUpdate { // skip setting the value to empty if it's already empty. if len(providerIds) == 0 { if u.params == nil { return u } if _, ok := u.params["providersToDelete"]; !ok { return u } } return u.set("providersToDelete", providerIds) } // revokeRefreshTokens revokes all refresh tokens for a user by setting the validSince property // to the present in epoch seconds. func (u *UserToUpdate) revokeRefreshTokens() *UserToUpdate { return u.set("validSince", strconv.FormatInt(time.Now().Unix(), 10)) } func (u *UserToUpdate) set(key string, value interface{}) *UserToUpdate { if u.params == nil { u.params = make(map[string]interface{}) } u.params[key] = value return u } func (u *UserToUpdate) validatedRequest() (map[string]interface{}, error) { if len(u.params) == 0 { // update without any parameters is never allowed return nil, fmt.Errorf("update parameters must not be nil or empty") } req := make(map[string]interface{}) for k, v := range u.params { if k == "mfaSettings" { mfaInfo, err := validateAndFormatMfaSettings(v.(MultiFactorSettings), updateUserMethod) if err != nil { return nil, err } // Request body ref: https://cloud.google.com/identity-platform/docs/reference/rest/v1/accounts/update req["mfa"] = multiFactorEnrollments{mfaInfo} } else { req[k] = v } } if email, ok := req["email"]; ok { if err := validateEmail(email.(string)); err != nil { return nil, err } } handleDeletion := func(key, deleteKey, deleteVal string) { var deleteList []string list, ok := req[deleteKey] if ok { deleteList = list.([]string) } req[deleteKey] = append(deleteList, deleteVal) delete(req, key) } if name, ok := req["displayName"]; ok { if name == "" { handleDeletion("displayName", "deleteAttribute", "DISPLAY_NAME") } 
else if err := validateDisplayName(name.(string)); err != nil { return nil, err } } if url, ok := req["photoUrl"]; ok { if url == "" { handleDeletion("photoUrl", "deleteAttribute", "PHOTO_URL") } else if err := validatePhotoURL(url.(string)); err != nil { return nil, err } } if phone, ok := req["phoneNumber"]; ok { if phone == "" { handleDeletion("phoneNumber", "deleteProvider", "phone") } else if err := validatePhone(phone.(string)); err != nil { return nil, err } } if claims, ok := req["customClaims"]; ok { cc, err := marshalCustomClaims(claims.(map[string]interface{})) if err != nil { return nil, err } req["customAttributes"] = cc delete(req, "customClaims") } if pw, ok := req["password"]; ok { if err := validatePassword(pw.(string)); err != nil { return nil, err } } if linkProviderUserInfo, ok := req["linkProviderUserInfo"]; ok { userProvider := linkProviderUserInfo.(*UserProvider) if err := validateProviderUserInfo(userProvider); err != nil { return nil, err } // Although we don't really advertise it, we want to also handle linking of // non-federated idps with this call. So if we detect one of them, we'll // adjust the properties parameter appropriately. This *does* imply that a // conflict could arise, e.g. if the user provides a phoneNumber property, // but also provides a providerToLink with a 'phone' provider id. In that // case, we'll return an error. if userProvider.ProviderID == "email" { if _, ok := req["email"]; ok { // We could relax this to only return an error if the email addrs don't // match. But for now, we'll be extra picky. 
return nil, errors.New( "both UserToUpdate.Email and UserToUpdate.ProviderToLink.ProviderID='email' " + "were set; to link to the email/password provider, only specify the " + "UserToUpdate.Email field") } req["email"] = userProvider.UID delete(req, "linkProviderUserInfo") } else if userProvider.ProviderID == "phone" { if _, ok := req["phoneNumber"]; ok { // We could relax this to only return an error if the phone numbers don't // match. But for now, we'll be extra picky. return nil, errors.New( "both UserToUpdate.PhoneNumber and UserToUpdate.ProviderToLink.ProviderID='phone' " + "were set; to link to the phone provider, only specify the " + "UserToUpdate.PhoneNumber field") } req["phoneNumber"] = userProvider.UID delete(req, "linkProviderUserInfo") } } if providersToDelete, ok := req["providersToDelete"]; ok { var deleteProvider []string list, ok := req["deleteProvider"] if ok { deleteProvider = list.([]string) } for _, providerToDelete := range providersToDelete.([]string) { if providerToDelete == "" { return nil, errors.New("providersToDelete must not include empty strings") } // If we've been told to unlink the phone provider both via setting // phoneNumber to "" *and* by setting providersToDelete to include // 'phone', then we'll reject that. Though it might also be reasonable to // relax this restriction and just unlink it. 
if providerToDelete == "phone" { for _, prov := range deleteProvider { if prov == "phone" { return nil, errors.New("both UserToUpdate.PhoneNumber='' and " + "UserToUpdate.ProvidersToDelete=['phone'] were set; to unlink from a " + "phone provider, only specify the UserToUpdate.PhoneNumber='' field") } } } deleteProvider = append(deleteProvider, providerToDelete) } req["deleteProvider"] = deleteProvider delete(req, "providersToDelete") } return req, nil } func marshalCustomClaims(claims map[string]interface{}) (string, error) { for _, key := range reservedClaims { if _, ok := claims[key]; ok { return "", fmt.Errorf("claim %q is reserved and must not be set", key) } } b, err := json.Marshal(claims) if err != nil { return "", fmt.Errorf("custom claims marshaling error: %v", err) } s := string(b) if s == "null" { s = "{}" // claims map has been explicitly set to nil for deletion. } if len(s) > maxLenPayloadCC { return "", fmt.Errorf("serialized custom claims must not exceed %d characters", maxLenPayloadCC) } return s, nil } // Error handlers. const ( // Backend-generated error codes configurationNotFound = "CONFIGURATION_NOT_FOUND" emailAlreadyExists = "EMAIL_ALREADY_EXISTS" emailNotFound = "EMAIL_NOT_FOUND" invalidDynamicLinkDomain = "INVALID_DYNAMIC_LINK_DOMAIN" invalidHostingLinkDomain = "INVALID_HOSTING_LINK_DOMAIN" phoneNumberAlreadyExists = "PHONE_NUMBER_ALREADY_EXISTS" tenantNotFound = "TENANT_NOT_FOUND" uidAlreadyExists = "UID_ALREADY_EXISTS" unauthorizedContinueURI = "UNAUTHORIZED_CONTINUE_URI" userNotFound = "USER_NOT_FOUND" ) // IsConfigurationNotFound checks if the given error was due to a non-existing IdP configuration. func IsConfigurationNotFound(err error) bool { return hasAuthErrorCode(err, configurationNotFound) } // IsEmailAlreadyExists checks if the given error was due to a duplicate email. 
func IsEmailAlreadyExists(err error) bool {
	return hasAuthErrorCode(err, emailAlreadyExists)
}

// IsEmailNotFound reports whether err carries the auth error code that marks
// an email address with no corresponding user record.
func IsEmailNotFound(err error) bool {
	return hasAuthErrorCode(err, emailNotFound)
}

// IsInsufficientPermission used to report insufficient-permission failures.
//
// Deprecated: Always returns false.
func IsInsufficientPermission(err error) bool {
	return false
}

// IsInvalidDynamicLinkDomain reports whether err carries the auth error code
// for an invalid dynamic link domain.
func IsInvalidDynamicLinkDomain(err error) bool {
	return hasAuthErrorCode(err, invalidDynamicLinkDomain)
}

// IsInvalidHostingLinkDomain reports whether err carries the auth error code
// for an invalid hosting link domain.
func IsInvalidHostingLinkDomain(err error) bool {
	return hasAuthErrorCode(err, invalidHostingLinkDomain)
}

// IsInvalidEmail used to report invalid-email failures.
//
// Deprecated: Always returns false.
func IsInvalidEmail(err error) bool {
	return false
}

// IsPhoneNumberAlreadyExists reports whether err carries the auth error code
// for a duplicate phone number.
func IsPhoneNumberAlreadyExists(err error) bool {
	return hasAuthErrorCode(err, phoneNumberAlreadyExists)
}

// IsProjectNotFound used to report non-existing project failures.
//
// Deprecated: Always returns false.
func IsProjectNotFound(err error) bool {
	return false
}

// IsTenantNotFound reports whether err carries the auth error code for a
// non-existing tenant ID.
func IsTenantNotFound(err error) bool {
	return hasAuthErrorCode(err, tenantNotFound)
}

// IsUIDAlreadyExists reports whether err carries the auth error code for a
// duplicate uid.
func IsUIDAlreadyExists(err error) bool {
	return hasAuthErrorCode(err, uidAlreadyExists)
}

// IsUnauthorizedContinueURI checks if the given error was due to an unauthorized continue URI domain.
func IsUnauthorizedContinueURI(err error) bool { return hasAuthErrorCode(err, unauthorizedContinueURI) } // IsUnknown checks if the given error was due to a unknown server error. // // Deprecated: Always returns false. func IsUnknown(err error) bool { return false } // IsUserNotFound checks if the given error was due to non-existing user. func IsUserNotFound(err error) bool { return hasAuthErrorCode(err, userNotFound) } // Validators. func validateDisplayName(val string) error { if val == "" { return fmt.Errorf("display name must be a non-empty string") } return nil } func validatePhotoURL(val string) error { if val == "" { return fmt.Errorf("photo url must be a non-empty string") } return nil } func validateEmail(email string) error { if email == "" { return fmt.Errorf("email must be a non-empty string") } if parts := strings.Split(email, "@"); len(parts) != 2 || parts[0] == "" || parts[1] == "" { return fmt.Errorf("malformed email string: %q", email) } return nil } func validatePassword(val string) error { if len(val) < 6 { return fmt.Errorf("password must be a string at least 6 characters long") } return nil } func validateUID(uid string) error { if uid == "" { return fmt.Errorf("uid must be a non-empty string") } if len(uid) > 128 { return fmt.Errorf("uid string must not be longer than 128 characters") } return nil } func validatePhone(phone string) error { if phone == "" { return fmt.Errorf("phone number must be a non-empty string") } if !regexp.MustCompile(`\+.*[0-9A-Za-z]`).MatchString(phone) { return fmt.Errorf("phone number must be a valid, E.164 compliant identifier") } return nil } func validateProviderUserInfo(p *UserProvider) error { if p.UID == "" { return fmt.Errorf("user provider must specify a uid") } if p.ProviderID == "" { return fmt.Errorf("user provider must specify a provider ID") } return nil } func validateProvider(providerID string, providerUID string) error { if providerID == "" { return fmt.Errorf("providerID must be a non-empty string") 
} else if providerUID == "" { return fmt.Errorf("providerUID must be a non-empty string") } return nil } func validateAndFormatMfaSettings(mfaSettings MultiFactorSettings, methodType string) ([]*multiFactorInfoResponse, error) { var mfaInfo []*multiFactorInfoResponse for _, multiFactorInfo := range mfaSettings.EnrolledFactors { if multiFactorInfo.FactorID == "" { return nil, fmt.Errorf("no factor id specified") } switch methodType { case createUserMethod: // Enrollment time and uid are not allowed for signupNewUser endpoint. They will automatically be provisioned server side. if multiFactorInfo.EnrollmentTimestamp != 0 { return nil, fmt.Errorf("\"EnrollmentTimeStamp\" is not supported when adding second factors via \"createUser()\"") } if multiFactorInfo.UID != "" { return nil, fmt.Errorf("\"uid\" is not supported when adding second factors via \"createUser()\"") } case updateUserMethod: default: return nil, fmt.Errorf("unsupported methodType: %s", methodType) } if err := validateDisplayName(multiFactorInfo.DisplayName); err != nil { return nil, fmt.Errorf("the second factor \"displayName\" for \"%s\" must be a valid non-empty string", multiFactorInfo.DisplayName) } if multiFactorInfo.FactorID == phoneMultiFactorID { if multiFactorInfo.Phone != nil { // If PhoneMultiFactorInfo is provided, validate its PhoneNumber field if err := validatePhone(multiFactorInfo.Phone.PhoneNumber); err != nil { return nil, fmt.Errorf("the second factor \"phoneNumber\" for \"%s\" must be a non-empty E.164 standard compliant identifier string", multiFactorInfo.Phone.PhoneNumber) } // No need for the else here since we are returning from the function } else if multiFactorInfo.PhoneNumber != "" { // PhoneMultiFactorInfo is nil, check the deprecated PhoneNumber field if err := validatePhone(multiFactorInfo.PhoneNumber); err != nil { return nil, fmt.Errorf("the second factor \"phoneNumber\" for \"%s\" must be a non-empty E.164 standard compliant identifier string", 
multiFactorInfo.PhoneNumber) } // The PhoneNumber field is deprecated, set it in PhoneMultiFactorInfo and inform about the deprecation. multiFactorInfo.Phone = &PhoneMultiFactorInfo{ PhoneNumber: multiFactorInfo.PhoneNumber, } } else { // Both PhoneMultiFactorInfo and deprecated PhoneNumber are missing. return nil, fmt.Errorf("\"PhoneMultiFactorInfo\" must be defined") } } obj, err := convertMultiFactorInfoToServerFormat(*multiFactorInfo) if err != nil { return nil, err } mfaInfo = append(mfaInfo, &obj) } return mfaInfo, nil } // End of validators // GetUser gets the user data corresponding to the specified user ID. func (c *baseClient) GetUser(ctx context.Context, uid string) (*UserRecord, error) { return c.getUser(ctx, &userQuery{ field: "localId", value: uid, label: "uid", }) } // GetUserByEmail gets the user data corresponding to the specified email. func (c *baseClient) GetUserByEmail(ctx context.Context, email string) (*UserRecord, error) { if err := validateEmail(email); err != nil { return nil, err } return c.getUser(ctx, &userQuery{ field: "email", value: email, }) } // GetUserByPhoneNumber gets the user data corresponding to the specified user phone number. func (c *baseClient) GetUserByPhoneNumber(ctx context.Context, phone string) (*UserRecord, error) { if err := validatePhone(phone); err != nil { return nil, err } return c.getUser(ctx, &userQuery{ field: "phoneNumber", value: phone, label: "phone number", }) } // GetUserByProviderID is an alias for GetUserByProviderUID. // // Deprecated: Use GetUserByProviderUID instead. func (c *baseClient) GetUserByProviderID(ctx context.Context, providerID string, providerUID string) (*UserRecord, error) { return c.GetUserByProviderUID(ctx, providerID, providerUID) } // GetUserByProviderUID gets the user data for the user corresponding to a given provider ID. // // See // https://firebase.google.com/docs/auth/admin/manage-users#retrieve_user_data // for code samples and detailed documentation. 
// // `providerID` indicates the provider, such as 'google.com' for the Google provider. // `providerUID` is the user identifier for the given provider. func (c *baseClient) GetUserByProviderUID(ctx context.Context, providerID string, providerUID string) (*UserRecord, error) { // Although we don't really advertise it, we want to also handle non-federated // IDPs with this call. So if we detect one of them, we'll reroute this // request appropriately. if providerID == "phone" { return c.GetUserByPhoneNumber(ctx, providerUID) } else if providerID == "email" { return c.GetUserByEmail(ctx, providerUID) } if err := validateProvider(providerID, providerUID); err != nil { return nil, err } getUsersResult, err := c.GetUsers(ctx, []UserIdentifier{&ProviderIdentifier{providerID, providerUID}}) if err != nil { return nil, err } if len(getUsersResult.Users) == 0 { return nil, &internal.FirebaseError{ ErrorCode: internal.NotFound, String: fmt.Sprintf("cannot find user from providerID: { %s, %s }", providerID, providerUID), Response: nil, Ext: map[string]interface{}{ authErrorCode: userNotFound, }, } } return getUsersResult.Users[0], nil } type userQuery struct { field string value string label string } func (q *userQuery) description() string { label := q.label if label == "" { label = q.field } return fmt.Sprintf("%s: %q", label, q.value) } func (q *userQuery) build() map[string]interface{} { return map[string]interface{}{ q.field: []string{q.value}, } } type getAccountInfoResponse struct { Users []*userQueryResponse `json:"users"` } func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) { var parsed getAccountInfoResponse resp, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed) if err != nil { return nil, err } if len(parsed.Users) == 0 { return nil, &internal.FirebaseError{ ErrorCode: internal.NotFound, String: fmt.Sprintf("no user exists with the %s", query.description()), Response: resp.LowLevelResponse(), Ext: 
map[string]interface{}{ authErrorCode: userNotFound, }, } } return parsed.Users[0].makeUserRecord() } // A UserIdentifier identifies a user to be looked up. type UserIdentifier interface { matches(ur *UserRecord) bool populate(req *getAccountInfoRequest) } // A UIDIdentifier is used for looking up an account by uid. // // See GetUsers function. type UIDIdentifier struct { UID string } func (id UIDIdentifier) matches(ur *UserRecord) bool { return id.UID == ur.UID } func (id UIDIdentifier) populate(req *getAccountInfoRequest) { req.LocalID = append(req.LocalID, id.UID) } // An EmailIdentifier is used for looking up an account by email. // // See GetUsers function. type EmailIdentifier struct { Email string } func (id EmailIdentifier) matches(ur *UserRecord) bool { return id.Email == ur.Email } func (id EmailIdentifier) populate(req *getAccountInfoRequest) { req.Email = append(req.Email, id.Email) } // A PhoneIdentifier is used for looking up an account by phone number. // // See GetUsers function. type PhoneIdentifier struct { PhoneNumber string } func (id PhoneIdentifier) matches(ur *UserRecord) bool { return id.PhoneNumber == ur.PhoneNumber } func (id PhoneIdentifier) populate(req *getAccountInfoRequest) { req.PhoneNumber = append(req.PhoneNumber, id.PhoneNumber) } // A ProviderIdentifier is used for looking up an account by federated provider. // // See GetUsers function. type ProviderIdentifier struct { ProviderID string ProviderUID string } func (id ProviderIdentifier) matches(ur *UserRecord) bool { for _, userInfo := range ur.ProviderUserInfo { if id.ProviderID == userInfo.ProviderID && id.ProviderUID == userInfo.UID { return true } } return false } func (id ProviderIdentifier) populate(req *getAccountInfoRequest) { req.FederatedUserID = append( req.FederatedUserID, federatedUserIdentifier{ProviderID: id.ProviderID, RawID: id.ProviderUID}) } // A GetUsersResult represents the result of the GetUsers() API. 
type GetUsersResult struct { // Set of UserRecords corresponding to the set of users that were requested. // Only users that were found are listed here. The result set is unordered. Users []*UserRecord // Set of UserIdentifiers that were requested, but not found. NotFound []UserIdentifier } type federatedUserIdentifier struct { ProviderID string `json:"providerId,omitempty"` RawID string `json:"rawId,omitempty"` } type getAccountInfoRequest struct { LocalID []string `json:"localId,omitempty"` Email []string `json:"email,omitempty"` PhoneNumber []string `json:"phoneNumber,omitempty"` FederatedUserID []federatedUserIdentifier `json:"federatedUserId,omitempty"` } func (req *getAccountInfoRequest) validate() error { for i := range req.LocalID { if err := validateUID(req.LocalID[i]); err != nil { return err } } for i := range req.Email { if err := validateEmail(req.Email[i]); err != nil { return err } } for i := range req.PhoneNumber { if err := validatePhone(req.PhoneNumber[i]); err != nil { return err } } for i := range req.FederatedUserID { id := &req.FederatedUserID[i] if err := validateProvider(id.ProviderID, id.RawID); err != nil { return err } } return nil } func isUserFound(id UserIdentifier, urs [](*UserRecord)) bool { for i := range urs { if id.matches(urs[i]) { return true } } return false } // GetUsers returns the user data corresponding to the specified identifiers. // // There are no ordering guarantees; in particular, the nth entry in the users // result list is not guaranteed to correspond to the nth entry in the input // parameters list. // // A maximum of 100 identifiers may be supplied. If more than 100 // identifiers are supplied, this method returns an error. // // Returns the corresponding user records. An error is returned instead if any // of the identifiers are invalid or if more than 100 identifiers are // specified. 
func (c *baseClient) GetUsers( ctx context.Context, identifiers []UserIdentifier, ) (*GetUsersResult, error) { if len(identifiers) == 0 { return &GetUsersResult{[](*UserRecord){}, [](UserIdentifier){}}, nil } else if len(identifiers) > maxGetAccountsBatchSize { return nil, fmt.Errorf( "`identifiers` parameter must have <= %d entries", maxGetAccountsBatchSize) } var request getAccountInfoRequest for i := range identifiers { identifiers[i].populate(&request) } if err := request.validate(); err != nil { return nil, err } var parsed getAccountInfoResponse if _, err := c.post(ctx, "/accounts:lookup", request, &parsed); err != nil { return nil, err } var userRecords [](*UserRecord) for _, user := range parsed.Users { userRecord, err := user.makeUserRecord() if err != nil { return nil, err } userRecords = append(userRecords, userRecord) } var notFound []UserIdentifier for i := range identifiers { if !isUserFound(identifiers[i], userRecords) { notFound = append(notFound, identifiers[i]) } } return &GetUsersResult{userRecords, notFound}, nil } type userQueryResponse struct { UID string `json:"localId,omitempty"` DisplayName string `json:"displayName,omitempty"` Email string `json:"email,omitempty"` PhoneNumber string `json:"phoneNumber,omitempty"` PhotoURL string `json:"photoUrl,omitempty"` CreationTimestamp int64 `json:"createdAt,string,omitempty"` LastLogInTimestamp int64 `json:"lastLoginAt,string,omitempty"` LastRefreshAt string `json:"lastRefreshAt,omitempty"` ProviderID string `json:"providerId,omitempty"` CustomAttributes string `json:"customAttributes,omitempty"` Disabled bool `json:"disabled,omitempty"` EmailVerified bool `json:"emailVerified,omitempty"` ProviderUserInfo []*UserInfo `json:"providerUserInfo,omitempty"` PasswordHash string `json:"passwordHash,omitempty"` PasswordSalt string `json:"salt,omitempty"` TenantID string `json:"tenantId,omitempty"` ValidSinceSeconds int64 `json:"validSince,string,omitempty"` MFAInfo []*multiFactorInfoResponse 
`json:"mfaInfo,omitempty"` } func (r *userQueryResponse) makeUserRecord() (*UserRecord, error) { exported, err := r.makeExportedUserRecord() if err != nil { return nil, err } return exported.UserRecord, nil } func (r *userQueryResponse) makeExportedUserRecord() (*ExportedUserRecord, error) { var customClaims map[string]interface{} if r.CustomAttributes != "" { if err := json.Unmarshal([]byte(r.CustomAttributes), &customClaims); err != nil { return nil, err } if len(customClaims) == 0 { customClaims = nil } } // If the password hash is redacted (probably due to missing permissions) // then clear it out, similar to how the salt is returned. (Otherwise, it // *looks* like a b64-encoded hash is present, which is confusing.) hash := r.PasswordHash if hash == b64Redacted { hash = "" } var lastRefreshTimestamp int64 if r.LastRefreshAt != "" { t, err := time.Parse(time.RFC3339, r.LastRefreshAt) if err != nil { return nil, err } lastRefreshTimestamp = t.Unix() * 1000 } // Map the MFA info to a slice of enrolled factors. Currently there is only // support for PhoneMultiFactorInfo. 
var enrolledFactors []*MultiFactorInfo for _, factor := range r.MFAInfo { var enrollmentTimestamp int64 if factor.EnrolledAt != "" { t, err := time.Parse(time.RFC3339, factor.EnrolledAt) if err != nil { return nil, err } enrollmentTimestamp = t.Unix() * 1000 } if factor.PhoneInfo != "" { enrolledFactors = append(enrolledFactors, &MultiFactorInfo{ UID: factor.MFAEnrollmentID, DisplayName: factor.DisplayName, EnrollmentTimestamp: enrollmentTimestamp, FactorID: phoneMultiFactorID, PhoneNumber: factor.PhoneInfo, Phone: &PhoneMultiFactorInfo{ PhoneNumber: factor.PhoneInfo, }, }) } else if factor.TOTPInfo != nil { enrolledFactors = append(enrolledFactors, &MultiFactorInfo{ UID: factor.MFAEnrollmentID, DisplayName: factor.DisplayName, EnrollmentTimestamp: enrollmentTimestamp, FactorID: totpMultiFactorID, TOTP: &TOTPMultiFactorInfo{}, }) } else { return nil, fmt.Errorf("unsupported multi-factor auth response: %#v", factor) } } return &ExportedUserRecord{ UserRecord: &UserRecord{ UserInfo: &UserInfo{ DisplayName: r.DisplayName, Email: r.Email, PhoneNumber: r.PhoneNumber, PhotoURL: r.PhotoURL, UID: r.UID, ProviderID: defaultProviderID, }, CustomClaims: customClaims, Disabled: r.Disabled, EmailVerified: r.EmailVerified, ProviderUserInfo: r.ProviderUserInfo, TenantID: r.TenantID, TokensValidAfterMillis: r.ValidSinceSeconds * 1000, UserMetadata: &UserMetadata{ LastLogInTimestamp: r.LastLogInTimestamp, CreationTimestamp: r.CreationTimestamp, LastRefreshTimestamp: lastRefreshTimestamp, }, MultiFactor: &MultiFactorSettings{ EnrolledFactors: enrolledFactors, }, }, PasswordHash: hash, PasswordSalt: r.PasswordSalt, }, nil } // CreateUser creates a new user with the specified properties. 
func (c *baseClient) CreateUser(ctx context.Context, user *UserToCreate) (*UserRecord, error) { uid, err := c.createUser(ctx, user) if err != nil { return nil, err } return c.GetUser(ctx, uid) } func (c *baseClient) createUser(ctx context.Context, user *UserToCreate) (string, error) { if user == nil { user = &UserToCreate{} } request, err := user.validatedRequest() if err != nil { return "", err } var result struct { UID string `json:"localId"` } _, err = c.post(ctx, "/accounts", request, &result) return result.UID, err } // UpdateUser updates an existing user account with the specified properties. func (c *baseClient) UpdateUser( ctx context.Context, uid string, user *UserToUpdate) (ur *UserRecord, err error) { if err := c.updateUser(ctx, uid, user); err != nil { return nil, err } return c.GetUser(ctx, uid) } // RevokeRefreshTokens revokes all refresh tokens issued to a user. // // RevokeRefreshTokens updates the user's TokensValidAfterMillis to the current UTC second. // It is important that the server on which this is called has its clock set correctly and synchronized. // // While this revokes all sessions for a specified user and disables any new ID tokens for existing sessions // from getting minted, existing ID tokens may remain active until their natural expiration (one hour). // To verify that ID tokens are revoked, use `verifyIdTokenAndCheckRevoked(ctx, idToken)`. func (c *baseClient) RevokeRefreshTokens(ctx context.Context, uid string) error { return c.updateUser(ctx, uid, (&UserToUpdate{}).revokeRefreshTokens()) } // SetCustomUserClaims sets additional claims on an existing user account. // // Custom claims set via this function can be used to define user roles and privilege levels. // These claims propagate to all the devices where the user is already signed in (after token // expiration or when token refresh is forced), and next time the user signs in. The claims // can be accessed via the user's ID token JWT. 
If a reserved OIDC claim is specified (sub, iat, // iss, etc), an error is thrown. Claims payload must also not be larger then 1000 characters // when serialized into a JSON string. func (c *baseClient) SetCustomUserClaims(ctx context.Context, uid string, customClaims map[string]interface{}) error { if customClaims == nil || len(customClaims) == 0 { customClaims = map[string]interface{}{} } return c.updateUser(ctx, uid, (&UserToUpdate{}).CustomClaims(customClaims)) } func (c *baseClient) updateUser(ctx context.Context, uid string, user *UserToUpdate) error { if err := validateUID(uid); err != nil { return err } if user == nil { return fmt.Errorf("update parameters must not be nil or empty") } request, err := user.validatedRequest() if err != nil { return err } request["localId"] = uid _, err = c.post(ctx, "/accounts:update", request, nil) return err } // DeleteUser deletes the user by the given UID. func (c *baseClient) DeleteUser(ctx context.Context, uid string) error { if err := validateUID(uid); err != nil { return err } payload := map[string]interface{}{ "localId": uid, } _, err := c.post(ctx, "/accounts:delete", payload, nil) return err } // A DeleteUsersResult represents the result of the DeleteUsers() call. type DeleteUsersResult struct { // The number of users that were deleted successfully (possibly zero). Users // that did not exist prior to calling DeleteUsers() are considered to be // successfully deleted. SuccessCount int // The number of users that failed to be deleted (possibly zero). FailureCount int // A list of DeleteUsersErrorInfo instances describing the errors that were // encountered during the deletion. Length of this list is equal to the value // of FailureCount. Errors []*DeleteUsersErrorInfo } // DeleteUsersErrorInfo represents an error encountered while deleting a user // account. // // The Index field corresponds to the index of the failed user in the uids // array that was passed to DeleteUsers(). 
type DeleteUsersErrorInfo struct { Index int `json:"index,omitEmpty"` Reason string `json:"message,omitEmpty"` } // DeleteUsers deletes the users specified by the given identifiers. // // Deleting a non-existing user won't generate an error. (i.e. this method is // idempotent.) Non-existing users are considered to be successfully // deleted, and are therefore counted in the DeleteUsersResult.SuccessCount // value. // // A maximum of 1000 identifiers may be supplied. If more than 1000 // identifiers are supplied, this method returns an error. // // This API is currently rate limited at the server to 1 QPS. If you exceed // this, you may get a quota exceeded error. Therefore, if you want to delete // more than 1000 users, you may need to add a delay to ensure you don't go // over this limit. // // Returns the total number of successful/failed deletions, as well as the // array of errors that correspond to the failed deletions. An error is // returned if any of the identifiers are invalid or if more than 1000 // identifiers are specified. 
func (c *baseClient) DeleteUsers(ctx context.Context, uids []string) (*DeleteUsersResult, error) { if len(uids) == 0 { return &DeleteUsersResult{}, nil } else if len(uids) > maxDeleteAccountsBatchSize { return nil, fmt.Errorf( "`uids` parameter must have <= %d entries", maxDeleteAccountsBatchSize) } var payload struct { LocalIds []string `json:"localIds"` Force bool `json:"force"` } payload.Force = true for i := range uids { if err := validateUID(uids[i]); err != nil { return nil, err } payload.LocalIds = append(payload.LocalIds, uids[i]) } type batchDeleteAccountsResponse struct { Errors []*DeleteUsersErrorInfo `json:"errors"` } resp := batchDeleteAccountsResponse{} if _, err := c.post(ctx, "/accounts:batchDelete", payload, &resp); err != nil { return nil, err } result := DeleteUsersResult{ FailureCount: len(resp.Errors), SuccessCount: len(uids) - len(resp.Errors), Errors: resp.Errors, } return &result, nil } // SessionCookie creates a new Firebase session cookie from the given ID token and expiry // duration. The returned JWT can be set as a server-side session cookie with a custom cookie // policy. Expiry duration must be at least 5 minutes but may not exceed 14 days. // // This function is only exposed via [auth.Client] for now, since the tenant-scoped variant // of it is currently not supported. 
func (c *baseClient) createSessionCookie( ctx context.Context, idToken string, expiresIn time.Duration, ) (string, error) { if idToken == "" { return "", errors.New("id token must not be empty") } if expiresIn < 5*time.Minute || expiresIn > 14*24*time.Hour { return "", errors.New("expiry duration must be between 5 minutes and 14 days") } payload := map[string]interface{}{ "idToken": idToken, "validDuration": int64(expiresIn.Seconds()), } var result struct { SessionCookie string `json:"sessionCookie"` } _, err := c.post(ctx, ":createSessionCookie", payload, &result) return result.SessionCookie, err } func (c *baseClient) post( ctx context.Context, path string, payload, resp interface{}, ) (*internal.Response, error) { url, err := c.makeUserMgtURL(path) if err != nil { return nil, err } req := &internal.Request{ Method: http.MethodPost, URL: url, Body: internal.NewJSONEntity(payload), } return c.httpClient.DoAndUnmarshal(ctx, req, resp) } func (c *baseClient) makeUserMgtURL(path string) (string, error) { if c.projectID == "" { return "", errors.New("project id not available") } var url string if c.tenantID != "" { url = fmt.Sprintf("%s/projects/%s/tenants/%s%s", c.userManagementEndpoint, c.projectID, c.tenantID, path) } else { url = fmt.Sprintf("%s/projects/%s%s", c.userManagementEndpoint, c.projectID, path) } return url, nil } type authError struct { code internal.ErrorCode message string authCode string } var serverError = map[string]*authError{ "CONFIGURATION_NOT_FOUND": { code: internal.NotFound, message: "no IdP configuration corresponding to the provided identifier", authCode: configurationNotFound, }, "DUPLICATE_EMAIL": { code: internal.AlreadyExists, message: "user with the provided email already exists", authCode: emailAlreadyExists, }, "DUPLICATE_LOCAL_ID": { code: internal.AlreadyExists, message: "user with the provided uid already exists", authCode: uidAlreadyExists, }, "EMAIL_EXISTS": { code: internal.AlreadyExists, message: "user with the provided email 
already exists", authCode: emailAlreadyExists, }, "EMAIL_NOT_FOUND": { code: internal.NotFound, message: "no user record found for the given email", authCode: emailNotFound, }, "INVALID_DYNAMIC_LINK_DOMAIN": { code: internal.InvalidArgument, message: "the provided dynamic link domain is not configured or authorized for the current project", authCode: invalidDynamicLinkDomain, }, "INVALID_HOSTING_LINK_DOMAIN": { code: internal.InvalidArgument, message: "the provided hosting link domain is not configured in Firebase Hosting or is not owned by the current project", authCode: invalidHostingLinkDomain, }, "PHONE_NUMBER_EXISTS": { code: internal.AlreadyExists, message: "user with the provided phone number already exists", authCode: phoneNumberAlreadyExists, }, "TENANT_NOT_FOUND": { code: internal.NotFound, message: "tenant with the specified ID does not exist", authCode: tenantNotFound, }, "UNAUTHORIZED_DOMAIN": { code: internal.InvalidArgument, message: "domain of the continue url is not whitelisted", authCode: unauthorizedContinueURI, }, "USER_NOT_FOUND": { code: internal.NotFound, message: "no user record found for the given identifier", authCode: userNotFound, }, } func handleHTTPError(resp *internal.Response) error { err := internal.NewFirebaseError(resp) code, detail := parseErrorResponse(resp) if authErr, ok := serverError[code]; ok { err.ErrorCode = authErr.code err.Ext[authErrorCode] = authErr.authCode if detail != "" { err.String = fmt.Sprintf("%s: %s", authErr.message, detail) } else { err.String = authErr.message } } return err } func parseErrorResponse(resp *internal.Response) (string, string) { var httpErr struct { Error struct { Message string `json:"message"` } `json:"error"` } // ignore any json parse errors at this level json.Unmarshal(resp.Body, &httpErr) // Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} code, detail := httpErr.Error.Message, "" idx := strings.Index(code, ":") if idx != -1 { detail = 
strings.TrimSpace(code[idx+1:]) code = code[:idx] } return code, detail } golang-google-firebase-go-4.18.0/auth/user_mgt_test.go000066400000000000000000001775351505612111400227000ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "bytes" "context" "encoding/base64" "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httptest" "reflect" "sort" "strconv" "strings" "testing" "time" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/internal" "google.golang.org/api/iterator" ) var testUser = &UserRecord{ UserInfo: &UserInfo{ UID: "testuser", Email: "testuser@example.com", PhoneNumber: "+1234567890", DisplayName: "Test User", PhotoURL: "http://www.example.com/testuser/photo.png", ProviderID: defaultProviderID, }, Disabled: false, EmailVerified: true, ProviderUserInfo: []*UserInfo{ { ProviderID: "password", DisplayName: "Test User", PhotoURL: "http://www.example.com/testuser/photo.png", Email: "testuser@example.com", UID: "testuid", }, { ProviderID: "phone", PhoneNumber: "+1234567890", UID: "testuid", }, }, TokensValidAfterMillis: 1494364393000, UserMetadata: &UserMetadata{ CreationTimestamp: 1234567890000, LastLogInTimestamp: 1233211232000, }, CustomClaims: map[string]interface{}{"admin": true, "package": "gold"}, TenantID: "testTenant", MultiFactor: &MultiFactorSettings{ EnrolledFactors: []*MultiFactorInfo{ { UID: "enrolledPhoneFactor", FactorID: "phone", 
EnrollmentTimestamp: 1614776780000, Phone: &PhoneMultiFactorInfo{ PhoneNumber: "+1234567890", }, PhoneNumber: "+1234567890", DisplayName: "My MFA Phone", }, { UID: "enrolledTOTPFactor", FactorID: "totp", EnrollmentTimestamp: 1614776780000, TOTP: &TOTPMultiFactorInfo{}, DisplayName: "My MFA TOTP", }, }, }, } var testUserWithoutMFA = &UserRecord{ UserInfo: &UserInfo{ UID: "testusernomfa", Email: "testusernomfa@example.com", PhoneNumber: "+1234567890", DisplayName: "Test User Without MFA", PhotoURL: "http://www.example.com/testusernomfa/photo.png", ProviderID: defaultProviderID, }, Disabled: false, EmailVerified: true, ProviderUserInfo: []*UserInfo{ { ProviderID: "password", DisplayName: "Test User Without MFA", PhotoURL: "http://www.example.com/testusernomfa/photo.png", Email: "testusernomfa@example.com", UID: "testuid", }, { ProviderID: "phone", PhoneNumber: "+1234567890", UID: "testuid", }, }, TokensValidAfterMillis: 1494364393000, UserMetadata: &UserMetadata{ CreationTimestamp: 1234567890000, LastLogInTimestamp: 1233211232000, }, CustomClaims: map[string]interface{}{"admin": true, "package": "gold"}, TenantID: "testTenant", MultiFactor: &MultiFactorSettings{}, } func TestGetUser(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() user, err := s.Client.GetUser(context.Background(), "ignored_id") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(user, testUser) { t.Errorf("GetUser() = %#v; want = %#v", user, testUser) } want := `{"localId":["ignored_id"]}` got := string(s.Rbody) if got != want { t.Errorf("GetUser() Req = %v; want = %v", got, want) } wantPath := "/projects/mock-project-id/accounts:lookup" if s.Req[0].RequestURI != wantPath { t.Errorf("GetUser() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestGetUserByEmail(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() user, err := s.Client.GetUserByEmail(context.Background(), "test@email.com") if err != nil { t.Fatal(err) } if !reflect.DeepEqual(user, 
testUser) {
		t.Errorf("GetUserByEmail() = %#v; want = %#v", user, testUser)
	}

	want := `{"email":["test@email.com"]}`
	got := string(s.Rbody)
	if got != want {
		t.Errorf("GetUserByEmail() Req = %v; want = %v", got, want)
	}

	wantPath := "/projects/mock-project-id/accounts:lookup"
	if s.Req[0].RequestURI != wantPath {
		t.Errorf("GetUserByEmail() URL = %q; want = %q", s.Req[0].RequestURI, wantPath)
	}
}

// TestGetUserByPhoneNumber calls GetUserByPhoneNumber against an echoServer
// primed with testGetUserResponse and checks the decoded record, the
// "phoneNumber" request body and the accounts:lookup request path.
func TestGetUserByPhoneNumber(t *testing.T) {
	s := echoServer(testGetUserResponse, t)
	defer s.Close()

	user, err := s.Client.GetUserByPhoneNumber(context.Background(), "+1234567890")
	if err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(user, testUser) {
		t.Errorf("GetUserByPhoneNumber() = %#v; want = %#v", user, testUser)
	}

	want := `{"phoneNumber":["+1234567890"]}`
	got := string(s.Rbody)
	if got != want {
		t.Errorf("GetUserByPhoneNumber() Req = %v; want = %v", got, want)
	}

	wantPath := "/projects/mock-project-id/accounts:lookup"
	if s.Req[0].RequestURI != wantPath {
		t.Errorf("GetUserByPhoneNumber() URL = %q; want = %q", s.Req[0].RequestURI, wantPath)
	}
}

// TestGetUserByProviderIDNotFound checks that an empty "users" list in the
// response makes GetUserByProviderUID return a nil record and an error that
// has the expected message and satisfies IsUserNotFound.
func TestGetUserByProviderIDNotFound(t *testing.T) {
	mockUsers := []byte(`{ "users": [] }`)
	s := echoServer(mockUsers, t)
	defer s.Close()

	userRecord, err := s.Client.GetUserByProviderUID(context.Background(), "google.com", "google_uid1")
	want := "cannot find user from providerID: { google.com, google_uid1 }"
	if userRecord != nil || err == nil || err.Error() != want || !IsUserNotFound(err) {
		t.Errorf("GetUserByProviderUID() = (%v, %q); want = (nil, %q)", userRecord, err, want)
	}
}

// TestGetUserByProviderId checks the request body GetUserByProviderUID sends
// for each provider: a federatedUserId entry for "google.com", and plain
// phoneNumber / email lookups for the "phone" and "email" providers.
func TestGetUserByProviderId(t *testing.T) {
	cases := []struct {
		providerID  string
		providerUID string
		want        string
	}{
		{
			"google.com",
			"google_uid1",
			`{"federatedUserId":[{"providerId":"google.com","rawId":"google_uid1"}]}`,
		}, {
			"phone",
			"+15555550001",
			`{"phoneNumber":["+15555550001"]}`,
		}, {
			"email",
			"user@example.com",
			`{"email":["user@example.com"]}`,
		},
	}

	// The resulting user isn't parsed, so it just needs to exist (even if it's empty).
mockUsers := []byte(`{ "users": [{}] }`)
	s := echoServer(mockUsers, t)
	defer s.Close()

	for _, tc := range cases {
		t.Run(tc.providerID+":"+tc.providerUID, func(t *testing.T) {

			_, err := s.Client.GetUserByProviderUID(context.Background(), tc.providerID, tc.providerUID)
			if err != nil {
				t.Fatalf("GetUserByProviderUID() = %q", err)
			}

			got := string(s.Rbody)
			if got != tc.want {
				t.Errorf("GetUserByProviderUID() Req = %v; want = %v", got, tc.want)
			}

			// The server is shared by all subtests, so inspect the most
			// recent request rather than s.Req[0]; otherwise only the first
			// subtest's URL is ever validated.
			wantPath := "/projects/mock-project-id/accounts:lookup"
			hr := s.Req[len(s.Req)-1]
			if hr.RequestURI != wantPath {
				t.Errorf("GetUserByProviderUID() URL = %q; want = %q", hr.RequestURI, wantPath)
			}
		})
	}
}

// TestInvalidGetUser checks that every single-user lookup rejects empty
// arguments with a nil result and a non-nil error. It uses a Client with an
// empty baseClient and no test server. For GetUserByProviderUID the exact
// error message is checked as well.
func TestInvalidGetUser(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{},
	}

	user, err := client.GetUser(context.Background(), "")
	if user != nil || err == nil {
		t.Errorf("GetUser('') = (%v, %v); want = (nil, error)", user, err)
	}

	user, err = client.GetUserByEmail(context.Background(), "")
	if user != nil || err == nil {
		t.Errorf("GetUserByEmail('') = (%v, %v); want = (nil, error)", user, err)
	}

	user, err = client.GetUserByPhoneNumber(context.Background(), "")
	if user != nil || err == nil {
		// Message previously named a non-existent "GetUserPhoneNumber".
		t.Errorf("GetUserByPhoneNumber('') = (%v, %v); want = (nil, error)", user, err)
	}

	userRecord, err := client.GetUserByProviderUID(context.Background(), "", "google_uid1")
	want := "providerID must be a non-empty string"
	if userRecord != nil || err == nil || err.Error() != want {
		t.Errorf("GetUserByProviderUID() = (%v, %q); want = (nil, %q)", userRecord, err, want)
	}

	userRecord, err = client.GetUserByProviderUID(context.Background(), "google.com", "")
	want = "providerUID must be a non-empty string"
	if userRecord != nil || err == nil || err.Error() != want {
		t.Errorf("GetUserByProviderUID() = (%v, %q); want = (nil, %q)", userRecord, err, want)
	}
}

// Checks to see if the users list contain the given uids. Order is ignored.
//
// Behaviour is undefined if there are duplicate entries in either of the
// slices.
// // This function is identical to the one in integration/auth/user_mgt_test.go func sameUsers(users [](*UserRecord), uids []string) bool { if len(users) != len(uids) { return false } sort.Slice(users, func(i, j int) bool { return users[i].UID < users[j].UID }) sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) for i := range users { if users[i].UID != uids[i] { return false } } return true } func TestGetUsersExceeds100(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } var identifiers [101]UserIdentifier for i := 0; i < 101; i++ { identifiers[i] = &UIDIdentifier{UID: fmt.Sprintf("id%d", i)} } getUsersResult, err := client.GetUsers(context.Background(), identifiers[:]) want := "`identifiers` parameter must have <= 100 entries" if getUsersResult != nil || err == nil || err.Error() != want { t.Errorf( "GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want) } } func TestGetUsersEmpty(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } getUsersResult, err := client.GetUsers(context.Background(), [](UserIdentifier){}) if getUsersResult == nil || err != nil { t.Fatalf("GetUsers([]) = %q", err) } if len(getUsersResult.Users) != 0 { t.Errorf("len(GetUsers([]).Users) = %d; want 0", len(getUsersResult.Users)) } if len(getUsersResult.NotFound) != 0 { t.Errorf("len(GetUsers([]).NotFound) = %d; want 0", len(getUsersResult.NotFound)) } } func TestGetUsersAllNonExisting(t *testing.T) { resp := `{ "kind" : "identitytoolkit#GetAccountInfoResponse", "users" : [] }` s := echoServer([]byte(resp), t) defer s.Close() notFoundIds := []UserIdentifier{&UIDIdentifier{"id that doesnt exist"}} getUsersResult, err := s.Client.GetUsers(context.Background(), notFoundIds) if err != nil { t.Fatalf("GetUsers() = %q", err) } if len(getUsersResult.Users) != 0 { t.Errorf( "len(GetUsers().Users) = %d; want 0", len(getUsersResult.Users)) } if len(getUsersResult.NotFound) != len(notFoundIds) { t.Errorf("len(GetUsers()).NotFound) = %d; want %d", 
len(getUsersResult.NotFound), len(notFoundIds))
	} else {
		// Lengths match, so compare the identifiers position by position.
		for i := range notFoundIds {
			if getUsersResult.NotFound[i] != notFoundIds[i] {
				t.Errorf("GetUsers().NotFound[%d] = %v; want %v",
					i, getUsersResult.NotFound[i], notFoundIds[i])
			}
		}
	}
}

// TestGetUsersInvalidUid checks that GetUsers rejects a UID identifier longer
// than 128 characters with a nil result and the expected error message.
func TestGetUsersInvalidUid(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{},
	}

	getUsersResult, err := client.GetUsers(
		context.Background(),
		[]UserIdentifier{&UIDIdentifier{"too long " + strings.Repeat(".", 128)}})
	want := "uid string must not be longer than 128 characters"
	if getUsersResult != nil || err == nil || err.Error() != want {
		t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want)
	}
}

// TestGetUsersInvalidEmail checks that GetUsers rejects a malformed email
// identifier with a nil result and the expected error message.
func TestGetUsersInvalidEmail(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{},
	}

	getUsersResult, err := client.GetUsers(
		context.Background(),
		[]UserIdentifier{EmailIdentifier{"invalid email addr"}})
	want := `malformed email string: "invalid email addr"`
	if getUsersResult != nil || err == nil || err.Error() != want {
		t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want)
	}
}

// TestGetUsersInvalidPhoneNumber checks that GetUsers rejects a phone
// identifier that is not E.164 compliant.
func TestGetUsersInvalidPhoneNumber(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{},
	}

	getUsersResult, err := client.GetUsers(context.Background(), []UserIdentifier{
		PhoneIdentifier{"invalid phone number"},
	})
	want := "phone number must be a valid, E.164 compliant identifier"
	if getUsersResult != nil || err == nil || err.Error() != want {
		t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want)
	}
}

// TestGetUsersInvalidProvider checks that GetUsers rejects a provider
// identifier whose ProviderID and ProviderUID are both empty. The expected
// error is the one for the empty providerID.
func TestGetUsersInvalidProvider(t *testing.T) {
	client := &Client{
		baseClient: &baseClient{},
	}

	getUsersResult, err := client.GetUsers(context.Background(), []UserIdentifier{
		ProviderIdentifier{ProviderID: "", ProviderUID: ""},
	})
	want := "providerID must be a non-empty string"
	if getUsersResult != nil || err == nil || err.Error() != want {
		t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want)
	}
}

// TestGetUsersSingleBadIdentifier checks that one invalid identifier among
// otherwise valid ones makes the whole GetUsers call fail.
func TestGetUsersSingleBadIdentifier(t
*testing.T) {
	client := &Client{
		baseClient: &baseClient{},
	}

	// The third identifier exceeds the 128-character UID limit.
	identifiers := []UserIdentifier{
		UIDIdentifier{"valid_id1"},
		UIDIdentifier{"valid_id2"},
		UIDIdentifier{"invalid id; too long. " + strings.Repeat(".", 128)},
		UIDIdentifier{"valid_id3"},
		UIDIdentifier{"valid_id4"},
	}

	getUsersResult, err := client.GetUsers(context.Background(), identifiers)
	want := "uid string must not be longer than 128 characters"
	if getUsersResult != nil || err == nil || err.Error() != want {
		t.Errorf("GetUsers() = (%v, %q); want = (nil, %q)", getUsersResult, err, want)
	}
}

// TestGetUsersMultipleIdentifierTypes requests users by UID, email, phone
// number and provider identifier in one GetUsers call, plus one UID that is
// absent from the mocked response. It expects the four mocked users back (in
// any order) and the absent UID reported in NotFound.
func TestGetUsersMultipleIdentifierTypes(t *testing.T) {
	mockUsers := []byte(` { "users": [{ "localId": "uid1", "email": "user1@example.com", "phoneNumber": "+15555550001" }, { "localId": "uid2", "email": "user2@example.com", "phoneNumber": "+15555550002" }, { "localId": "uid3", "email": "user3@example.com", "phoneNumber": "+15555550003" }, { "localId": "uid4", "email": "user4@example.com", "phoneNumber": "+15555550004", "providerUserInfo": [{ "providerId": "google.com", "rawId": "google_uid4" }] }] }`)
	s := echoServer(mockUsers, t)
	defer s.Close()

	identifiers := []UserIdentifier{
		&UIDIdentifier{"uid1"},
		&EmailIdentifier{"user2@example.com"},
		&PhoneIdentifier{"+15555550003"},
		&ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_uid4"},
		&UIDIdentifier{"this-user-doesnt-exist"},
	}

	getUsersResult, err := s.Client.GetUsers(context.Background(), identifiers)
	if err != nil {
		t.Fatalf("GetUsers() = %q", err)
	}

	if !sameUsers(getUsersResult.Users, []string{"uid1", "uid2", "uid3", "uid4"}) {
		t.Errorf("GetUsers() = %v; want = (uids from) %v (in any order)",
			getUsersResult.Users, []string{"uid1", "uid2", "uid3", "uid4"})
	}
	if len(getUsersResult.NotFound) != 1 {
		t.Errorf("GetUsers() = %d; want = 1", len(getUsersResult.NotFound))
	} else {
		// The single not-found entry must be the *UIDIdentifier passed in above.
		if id, ok := getUsersResult.NotFound[0].(*UIDIdentifier); !ok {
			t.Errorf("GetUsers().NotFound[0] not a UIDIdentifier")
		} else {
			if id.UID != "this-user-doesnt-exist" {
t.Errorf("GetUsers().NotFound[0].UID = %s; want = 'this-user-doesnt-exist'", id.UID)
			}
		}
	}
}

// TestGetNonExistingUser checks that GetUser, GetUserByEmail and
// GetUserByPhoneNumber each return a nil record and a "no user exists" error
// satisfying IsUserNotFound when the response has an empty "users" list.
func TestGetNonExistingUser(t *testing.T) {
	resp := `{ "kind" : "identitytoolkit#GetAccountInfoResponse", "users" : [] }`
	s := echoServer([]byte(resp), t)
	defer s.Close()

	we := `no user exists with the uid: "id-nonexisting"`
	user, err := s.Client.GetUser(context.Background(), "id-nonexisting")
	if user != nil || err == nil || err.Error() != we || !IsUserNotFound(err) {
		t.Errorf("GetUser(non-existing) = (%v, %q); want = (nil, %q)", user, err, we)
	}

	we = `no user exists with the email: "foo@bar.nonexisting"`
	user, err = s.Client.GetUserByEmail(context.Background(), "foo@bar.nonexisting")
	if user != nil || err == nil || err.Error() != we || !IsUserNotFound(err) {
		t.Errorf("GetUserByEmail(non-existing) = (%v, %q); want = (nil, %q)", user, err, we)
	}

	we = `no user exists with the phone number: "+12345678901"`
	user, err = s.Client.GetUserByPhoneNumber(context.Background(), "+12345678901")
	if user != nil || err == nil || err.Error() != we || !IsUserNotFound(err) {
		// Message previously named a non-existent "GetUserPhoneNumber".
		t.Errorf("GetUserByPhoneNumber(non-existing) = (%v, %q); want = (nil, %q)", user, err, we)
	}
}

// TestListUsers iterates Users() over the list_users.json fixture, with and
// without a page token. It checks each returned record, its password hash
// and salt, the total count, and the query string and path of the last HTTP
// request.
func TestListUsers(t *testing.T) {
	testListUsersResponse, err := ioutil.ReadFile("../testdata/list_users.json")
	if err != nil {
		t.Fatal(err)
	}
	s := echoServer(testListUsersResponse, t)
	defer s.Close()

	want := []*ExportedUserRecord{
		{UserRecord: testUser, PasswordHash: "passwordhash1", PasswordSalt: "salt1"},
		{UserRecord: testUser, PasswordHash: "passwordhash2", PasswordSalt: "salt2"},
		{UserRecord: testUserWithoutMFA, PasswordHash: "passwordhash3", PasswordSalt: "salt3"},
	}

	// testIterator drains iter and compares it against want; token is only
	// used to label failures, req is the expected encoded query string.
	testIterator := func(iter *UserIterator, token string, req string) {
		count := 0
		for i := 0; i < len(want); i++ {
			user, err := iter.Next()
			if err == iterator.Done {
				break
			}
			if err != nil {
				t.Fatal(err)
			}
			if !reflect.DeepEqual(user.UserRecord, want[i].UserRecord) {
				t.Errorf("Users(%q) = %#v; want = %#v", token, user, want[i])
			}
			if user.PasswordHash !=
want[i].PasswordHash {
				t.Errorf("Users(%q) PasswordHash = %q; want = %q", token, user.PasswordHash, want[i].PasswordHash)
			}
			if user.PasswordSalt != want[i].PasswordSalt {
				t.Errorf("Users(%q) PasswordSalt = %q; want = %q", token, user.PasswordSalt, want[i].PasswordSalt)
			}
			count++
		}
		if count != len(want) {
			t.Errorf("Users(%q) = %d; want = %d", token, count, len(want))
		}
		// After len(want) items the iterator must be exhausted.
		if _, err := iter.Next(); err != iterator.Done {
			t.Errorf("Users(%q) = %v, want = %v", token, err, iterator.Done)
		}

		hr := s.Req[len(s.Req)-1] // Check the query string of the last HTTP request made.
		gotReq := hr.URL.Query().Encode()
		if gotReq != req {
			t.Errorf("Users(%q) = %q, want = %v", token, gotReq, req)
		}

		wantPath := "/projects/mock-project-id/accounts:batchGet"
		if hr.URL.Path != wantPath {
			t.Errorf("Users(%q) URL = %q; want = %q", token, hr.URL.Path, wantPath)
		}
	}

	// First without a page token, then starting from "pageToken".
	testIterator(
		s.Client.Users(context.Background(), ""),
		"",
		"maxResults=1000")
	testIterator(
		s.Client.Users(context.Background(), "pageToken"),
		"pageToken",
		"maxResults=1000&nextPageToken=pageToken")
}

// TestInvalidCreateUser feeds CreateUser a table of invalid UserToCreate
// values and checks that each yields a nil user and the exact error message
// listed alongside it.
func TestInvalidCreateUser(t *testing.T) {
	cases := []struct {
		params *UserToCreate
		want   string
	}{
		{
			(&UserToCreate{}).Password("short"),
			"password must be a string at least 6 characters long",
		}, {
			(&UserToCreate{}).PhoneNumber(""),
			"phone number must be a non-empty string",
		}, {
			(&UserToCreate{}).PhoneNumber("1234"),
			"phone number must be a valid, E.164 compliant identifier",
		}, {
			(&UserToCreate{}).PhoneNumber("+_!@#$"),
			"phone number must be a valid, E.164 compliant identifier",
		}, {
			(&UserToCreate{}).UID(""),
			"uid must be a non-empty string",
		}, {
			(&UserToCreate{}).UID(strings.Repeat("a", 129)),
			"uid string must not be longer than 128 characters",
		}, {
			(&UserToCreate{}).DisplayName(""),
			"display name must be a non-empty string",
		}, {
			(&UserToCreate{}).PhotoURL(""),
			"photo url must be a non-empty string",
		}, {
			(&UserToCreate{}).Email(""),
			"email must be a non-empty string",
		}, {
			(&UserToCreate{}).Email("a"),
			`malformed email string: 
"a"`,
		}, {
			(&UserToCreate{}).Email("a@"),
			`malformed email string: "a@"`,
		}, {
			(&UserToCreate{}).Email("@a"),
			`malformed email string: "@a"`,
		}, {
			(&UserToCreate{}).Email("a@a@a"),
			`malformed email string: "a@a@a"`,
		}, {
			// An explicit UID on a second factor is rejected on create.
			(&UserToCreate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						UID: "EnrollmentID",
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "+11234567890",
						},
						DisplayName: "Spouse's phone number",
						FactorID:    "phone",
					},
				},
			}),
			`"uid" is not supported when adding second factors via "createUser()"`,
		}, {
			(&UserToCreate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "invalid",
						},
						DisplayName: "Spouse's phone number",
						FactorID:    "phone",
					},
				},
			}),
			`the second factor "phoneNumber" for "invalid" must be a non-empty E.164 standard compliant identifier string`,
		}, {
			// An enrollment timestamp is rejected on create.
			(&UserToCreate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "+11234567890",
						},
						DisplayName:         "Spouse's phone number",
						FactorID:            "phone",
						EnrollmentTimestamp: time.Now().UTC().Unix(),
					},
				},
			}),
			`"EnrollmentTimeStamp" is not supported when adding second factors via "createUser()"`,
		}, {
			(&UserToCreate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "+11234567890",
						},
						DisplayName: "Spouse's phone number",
						FactorID:    "",
					},
				},
			}),
			`no factor id specified`,
		}, {
			(&UserToCreate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "+11234567890",
						},
						FactorID: "phone",
					},
				},
			}),
			`the second factor "displayName" for "" must be a valid non-empty string`,
		},
	}
	client := &Client{
		baseClient: &baseClient{},
	}
	for i, tc := range cases {
		user, err := client.CreateUser(context.Background(), tc.params)
		if user != nil || err == nil {
			t.Errorf("[%d] CreateUser() = (%v, %v); want = (nil, error)", i, user, err)
		}
		// The failure is already reported above; skip the message check so
		// that err.Error() is never called on a nil error (which would panic
		// and abort the remaining cases).
		if err == nil {
			continue
		}
		if err.Error() !=
tc.want {
			t.Errorf("[%d] CreateUser() = %v; want = %v", i, err.Error(), tc.want)
		}
	}
}

// createUserCases pairs UserToCreate inputs with the request payload
// TestCreateUser expects to be sent for them (compared after json.Marshal).
var createUserCases = []struct {
	params *UserToCreate
	req    map[string]interface{}
}{
	{
		// A nil params value is expected to produce an empty request.
		nil,
		map[string]interface{}{},
	},
	{
		&UserToCreate{},
		map[string]interface{}{},
	},
	{
		(&UserToCreate{}).Password("123456"),
		map[string]interface{}{"password": "123456"},
	},
	{
		(&UserToCreate{}).UID("1"),
		map[string]interface{}{"localId": "1"},
	},
	{
		// 128 characters: the longest UID accepted.
		(&UserToCreate{}).UID(strings.Repeat("a", 128)),
		map[string]interface{}{"localId": strings.Repeat("a", 128)},
	},
	{
		(&UserToCreate{}).PhoneNumber("+1"),
		map[string]interface{}{"phoneNumber": "+1"},
	},
	{
		(&UserToCreate{}).DisplayName("a"),
		map[string]interface{}{"displayName": "a"},
	},
	{
		(&UserToCreate{}).Email("a@a"),
		map[string]interface{}{"email": "a@a"},
	},
	{
		(&UserToCreate{}).Disabled(true),
		map[string]interface{}{"disabled": true},
	},
	{
		(&UserToCreate{}).Disabled(false),
		map[string]interface{}{"disabled": false},
	},
	{
		(&UserToCreate{}).EmailVerified(true),
		map[string]interface{}{"emailVerified": true},
	},
	{
		(&UserToCreate{}).EmailVerified(false),
		map[string]interface{}{"emailVerified": false},
	},
	{
		(&UserToCreate{}).PhotoURL("http://some.url"),
		map[string]interface{}{"photoUrl": "http://some.url"},
	},
	{
		// One factor uses the nested Phone struct, the other the top-level
		// PhoneNumber field; both are expected to serialize to phoneInfo.
		(&UserToCreate{}).MFASettings(MultiFactorSettings{
			EnrolledFactors: []*MultiFactorInfo{
				{
					Phone: &PhoneMultiFactorInfo{
						PhoneNumber: "+11234567890",
					},
					DisplayName: "Phone Number active",
					FactorID:    "phone",
				},
				{
					PhoneNumber: "+11234567890",
					DisplayName: "Phone Number deprecated",
					FactorID:    "phone",
				},
			},
		}),
		map[string]interface{}{"mfaInfo": []*multiFactorInfoResponse{
			{
				PhoneInfo:   "+11234567890",
				DisplayName: "Phone Number active",
			},
			{
				PhoneInfo:   "+11234567890",
				DisplayName: "Phone Number deprecated",
			},
		},
		},
	},
	{
		(&UserToCreate{}).MFASettings(MultiFactorSettings{
			EnrolledFactors: []*MultiFactorInfo{
				{
					Phone: &PhoneMultiFactorInfo{
						PhoneNumber: "+11234567890",
					},
					DisplayName: "number1",
					FactorID:    "phone",
				},
				{
					Phone: &PhoneMultiFactorInfo{
PhoneNumber: "+11234567890",
					},
					DisplayName: "number2",
					FactorID:    "phone",
				},
			},
		}),
		map[string]interface{}{"mfaInfo": []*multiFactorInfoResponse{
			{
				PhoneInfo:   "+11234567890",
				DisplayName: "number1",
			},
			{
				PhoneInfo:   "+11234567890",
				DisplayName: "number2",
			},
		},
		},
	},
}

// TestCreateUser runs every createUserCases entry through createUser against
// an echoServer. It checks the returned uid, the JSON request body and the
// request path.
func TestCreateUser(t *testing.T) {
	resp := `{ "kind": "identitytoolkit#SignupNewUserResponse", "localId": "expectedUserID" }`
	s := echoServer([]byte(resp), t)
	defer s.Close()

	wantPath := "/projects/mock-project-id/accounts"
	for _, tc := range createUserCases {
		uid, err := s.Client.createUser(context.Background(), tc.params)
		if uid != "expectedUserID" || err != nil {
			t.Errorf("createUser(%#v) = (%q, %v); want = (%q, nil)", tc.params, uid, err, "expectedUserID")
		}

		want, err := json.Marshal(tc.req)
		if err != nil {
			t.Fatal(err)
		}
		if !reflect.DeepEqual(s.Rbody, want) {
			t.Errorf("createUser(%#v) request = %v; want = %v", tc.params, string(s.Rbody), string(want))
		}

		// Check the request made by this iteration. s.Req[0] is always the
		// first case's request, so it validated the URL only once.
		hr := s.Req[len(s.Req)-1]
		if hr.RequestURI != wantPath {
			t.Errorf("createUser(%#v) URL = %q; want = %q", tc.params, hr.RequestURI, wantPath)
		}
	}
}

// TestInvalidUpdateUser feeds UpdateUser a table of invalid UserToUpdate
// values (plus one case per reserved claim) and checks that each yields a nil
// user and the exact error message listed alongside it.
func TestInvalidUpdateUser(t *testing.T) {
	cases := []struct {
		params *UserToUpdate
		want   string
	}{
		{
			nil,
			"update parameters must not be nil or empty",
		}, {
			&UserToUpdate{},
			"update parameters must not be nil or empty",
		}, {
			(&UserToUpdate{}).Email(""),
			"email must be a non-empty string",
		}, {
			(&UserToUpdate{}).Email("invalid"),
			`malformed email string: "invalid"`,
		}, {
			(&UserToUpdate{}).PhoneNumber("1"),
			"phone number must be a valid, E.164 compliant identifier",
		}, {
			(&UserToUpdate{}).CustomClaims(map[string]interface{}{"a": strings.Repeat("a", 993)}),
			"serialized custom claims must not exceed 1000 characters",
		}, {
			(&UserToUpdate{}).Password("short"),
			"password must be a string at least 6 characters long",
		}, {
			(&UserToUpdate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						UID: "enrolledSecondFactor1",
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "+11234567890",
						},
						FactorID: "phone",
					},
},
			}),
			`the second factor "displayName" for "" must be a valid non-empty string`,
		}, {
			(&UserToUpdate{}).MFASettings(MultiFactorSettings{
				EnrolledFactors: []*MultiFactorInfo{
					{
						UID: "enrolledSecondFactor1",
						Phone: &PhoneMultiFactorInfo{
							PhoneNumber: "invalid",
						},
						DisplayName: "Spouse's phone number",
						FactorID:    "phone",
					},
				},
			}),
			`the second factor "phoneNumber" for "invalid" must be a non-empty E.164 standard compliant identifier string`,
		}, {
			// Provider links must carry both a provider ID and a uid.
			(&UserToUpdate{}).ProviderToLink(&UserProvider{UID: "google_uid"}),
			"user provider must specify a provider ID",
		}, {
			(&UserToUpdate{}).ProviderToLink(&UserProvider{ProviderID: "google.com"}),
			"user provider must specify a uid",
		}, {
			(&UserToUpdate{}).ProviderToLink(&UserProvider{ProviderID: "google.com", UID: ""}),
			"user provider must specify a uid",
		}, {
			(&UserToUpdate{}).ProviderToLink(&UserProvider{ProviderID: "", UID: "google_uid"}),
			"user provider must specify a provider ID",
		}, {
			(&UserToUpdate{}).ProvidersToDelete([]string{""}),
			"providersToDelete must not include empty strings",
		}, {
			// Setting Email and also linking the "email" provider is rejected.
			(&UserToUpdate{}).
				Email("user@example.com").
				ProviderToLink(&UserProvider{
					ProviderID: "email",
					UID:        "user@example.com",
				}),
			"both UserToUpdate.Email and UserToUpdate.ProviderToLink.ProviderID='email' " +
				"were set; to link to the email/password provider, only specify the " +
				"UserToUpdate.Email field",
		}, {
			// Setting PhoneNumber and also linking the "phone" provider is rejected.
			(&UserToUpdate{}).
				PhoneNumber("+15555550001").
				ProviderToLink(&UserProvider{
					ProviderID: "phone",
					UID:        "+15555550001",
				}),
			"both UserToUpdate.PhoneNumber and UserToUpdate.ProviderToLink.ProviderID='phone' " +
				"were set; to link to the phone provider, only specify the " +
				"UserToUpdate.PhoneNumber field",
		}, {
			// Clearing PhoneNumber and also deleting the "phone" provider is rejected.
			(&UserToUpdate{}).
				PhoneNumber("").
ProvidersToDelete([]string{"phone"}), "both UserToUpdate.PhoneNumber='' and " + "UserToUpdate.ProvidersToDelete=['phone'] were set; to unlink from a " + "phone provider, only specify the UserToUpdate.PhoneNumber='' field", }, } for _, claim := range reservedClaims { s := struct { params *UserToUpdate want string }{ (&UserToUpdate{}).CustomClaims(map[string]interface{}{claim: true}), fmt.Sprintf("claim %q is reserved and must not be set", claim), } cases = append(cases, s) } client := &Client{ baseClient: &baseClient{}, } for i, tc := range cases { user, err := client.UpdateUser(context.Background(), "uid", tc.params) if user != nil || err == nil { t.Errorf("[%d] UpdateUser() = (%v, %v); want = (nil, error)", i, user, err) } if err.Error() != tc.want { t.Errorf("[%d] UpdateUser() = %v; want = %v", i, err.Error(), tc.want) } } } func TestUpdateUserEmptyUID(t *testing.T) { params := (&UserToUpdate{}).DisplayName("test") client := &Client{ baseClient: &baseClient{}, } user, err := client.UpdateUser(context.Background(), "", params) if user != nil || err == nil { t.Errorf("UpdateUser('') = (%v, %v); want = (nil, error)", user, err) } } var updateUserCases = []struct { params *UserToUpdate req map[string]interface{} }{ { (&UserToUpdate{}).Password("123456"), map[string]interface{}{"password": "123456"}, }, { (&UserToUpdate{}).PhoneNumber("+1"), map[string]interface{}{"phoneNumber": "+1"}, }, { (&UserToUpdate{}).DisplayName("a"), map[string]interface{}{"displayName": "a"}, }, { (&UserToUpdate{}).Email("a@a"), map[string]interface{}{"email": "a@a"}, }, { (&UserToUpdate{}).Disabled(true), map[string]interface{}{"disableUser": true}, }, { (&UserToUpdate{}).Disabled(false), map[string]interface{}{"disableUser": false}, }, { (&UserToUpdate{}).EmailVerified(true), map[string]interface{}{"emailVerified": true}, }, { (&UserToUpdate{}).EmailVerified(false), map[string]interface{}{"emailVerified": false}, }, { (&UserToUpdate{}).PhotoURL("http://some.url"), 
map[string]interface{}{"photoUrl": "http://some.url"},
	},
	{
		// Empty strings are expected to turn into delete requests.
		(&UserToUpdate{}).DisplayName(""),
		map[string]interface{}{"deleteAttribute": []string{"DISPLAY_NAME"}},
	},
	{
		(&UserToUpdate{}).PhoneNumber(""),
		map[string]interface{}{"deleteProvider": []string{"phone"}},
	},
	{
		(&UserToUpdate{}).PhotoURL(""),
		map[string]interface{}{"deleteAttribute": []string{"PHOTO_URL"}},
	},
	{
		(&UserToUpdate{}).PhotoURL("").PhoneNumber("").DisplayName(""),
		map[string]interface{}{
			"deleteAttribute": []string{"DISPLAY_NAME", "PHOTO_URL"},
			"deleteProvider":  []string{"phone"},
		},
	},
	{
		(&UserToUpdate{}).MFASettings(MultiFactorSettings{
			EnrolledFactors: []*MultiFactorInfo{
				{
					UID: "enrolledSecondFactor1",
					Phone: &PhoneMultiFactorInfo{
						PhoneNumber: "+11234567890",
					},
					DisplayName:         "Spouse's phone number",
					FactorID:            "phone",
					EnrollmentTimestamp: time.Now().Unix(),
				}, {
					UID: "enrolledSecondFactor2",
					Phone: &PhoneMultiFactorInfo{
						PhoneNumber: "+11234567890",
					},
					PhoneNumber: "+11234567890",
					DisplayName: "Spouse's phone number",
					FactorID:    "phone",
				}, {
					Phone: &PhoneMultiFactorInfo{
						PhoneNumber: "+11234567890",
					},
					PhoneNumber: "+11234567890",
					DisplayName: "Spouse's phone number",
					FactorID:    "phone",
				},
			},
		}),
		map[string]interface{}{"mfa": multiFactorEnrollments{Enrollments: []*multiFactorInfoResponse{
			{
				MFAEnrollmentID: "enrolledSecondFactor1",
				PhoneInfo:       "+11234567890",
				DisplayName:     "Spouse's phone number",
				// NOTE(review): this layout has a literal "Z" after the zone
				// offset and is evaluated with a separate time.Now() call from
				// the EnrollmentTimestamp above, so the expected string depends
				// on the wall clock. Confirm it matches what updateUser emits.
				EnrolledAt: time.Now().Format("2006-01-02T15:04:05Z07:00Z"),
			},
			{
				MFAEnrollmentID: "enrolledSecondFactor2",
				DisplayName:     "Spouse's phone number",
				PhoneInfo:       "+11234567890",
			},
			{
				DisplayName: "Spouse's phone number",
				PhoneInfo:   "+11234567890",
			},
		}},
		},
	},
	{
		// Empty MFA settings are expected to send a nil enrollment list.
		(&UserToUpdate{}).MFASettings(MultiFactorSettings{}),
		map[string]interface{}{"mfa": multiFactorEnrollments{Enrollments: nil}},
	},
	{
		(&UserToUpdate{}).ProviderToLink(&UserProvider{
			ProviderID: "google.com",
			UID:        "google_uid",
		}),
		map[string]interface{}{
			"linkProviderUserInfo": &UserProvider{
				ProviderID: "google.com",
				UID:        "google_uid",
			}},
	},
	{
(&UserToUpdate{}).PhoneNumber("").ProvidersToDelete([]string{"google.com"}),
		map[string]interface{}{
			"deleteProvider": []string{"phone", "google.com"},
		},
	},
	{
		(&UserToUpdate{}).ProvidersToDelete([]string{"email", "phone"}),
		map[string]interface{}{
			"deleteProvider": []string{"email", "phone"},
		},
	},
	{
		// Linking the "email" provider is expected to be sent as a plain
		// email update.
		(&UserToUpdate{}).ProviderToLink(&UserProvider{
			ProviderID: "email",
			UID:        "user@example.com",
		}),
		map[string]interface{}{"email": "user@example.com"},
	},
	{
		// Likewise, linking the "phone" provider becomes a phoneNumber update.
		(&UserToUpdate{}).ProviderToLink(&UserProvider{
			ProviderID: "phone",
			UID:        "+15555550001",
		}),
		map[string]interface{}{"phoneNumber": "+15555550001"},
	},
	{
		(&UserToUpdate{}).CustomClaims(map[string]interface{}{"a": strings.Repeat("a", 992)}),
		map[string]interface{}{"customAttributes": fmt.Sprintf(`{"a":%q}`, strings.Repeat("a", 992))},
	},
	{
		(&UserToUpdate{}).CustomClaims(map[string]interface{}{}),
		map[string]interface{}{"customAttributes": "{}"},
	},
	{
		(&UserToUpdate{}).CustomClaims(nil),
		map[string]interface{}{"customAttributes": "{}"},
	},
}

// TestUpdateUser runs every updateUserCases entry through updateUser against
// an echoServer. It checks that the call succeeds, and verifies the JSON
// request body (with "localId" added) and the request path.
func TestUpdateUser(t *testing.T) {
	resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "expectedUserID" }`
	s := echoServer([]byte(resp), t)
	defer s.Close()

	wantPath := "/projects/mock-project-id/accounts:update"
	for _, tc := range updateUserCases {
		err := s.Client.updateUser(context.Background(), "uid", tc.params)
		if err != nil {
			t.Errorf("updateUser(%v) = %v; want = nil", tc.params, err)
		}

		// The uid passed to updateUser is expected in every request body.
		tc.req["localId"] = "uid"
		want, err := json.Marshal(tc.req)
		if err != nil {
			t.Fatal(err)
		}
		if !reflect.DeepEqual(s.Rbody, want) {
			t.Errorf("updateUser() request = %v; want = %v", string(s.Rbody), string(want))
		}

		// Check the request made by this iteration. s.Req[0] is always the
		// first case's request, so it validated the URL only once.
		hr := s.Req[len(s.Req)-1]
		if hr.RequestURI != wantPath {
			t.Errorf("updateUser(%#v) URL = %q; want = %q", tc.params, hr.RequestURI, wantPath)
		}
	}
}

// TestRevokeRefreshTokens checks that RevokeRefreshTokens sends a
// "validSince" value lying between the Unix times sampled immediately before
// and after the call, and that the request goes to the accounts:update path.
func TestRevokeRefreshTokens(t *testing.T) {
	resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "expectedUserID" }`
	s := echoServer([]byte(resp), t)
	defer s.Close()
	before := time.Now().Unix()
	if err :=
s.Client.RevokeRefreshTokens(context.Background(), "some_uid"); err != nil { t.Error(err) } after := time.Now().Unix() var req struct { ValidSince string `json:"validSince"` } if err := json.Unmarshal(s.Rbody, &req); err != nil { t.Fatal(err) } validSince, err := strconv.ParseInt(req.ValidSince, 10, 64) if err != nil { t.Fatal(err) } if validSince > after || validSince < before { t.Errorf("validSince = %d, expecting time between %d and %d", validSince, before, after) } wantPath := "/projects/mock-project-id/accounts:update" if s.Req[0].RequestURI != wantPath { t.Errorf("RevokeRefreshTokens() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestRevokeRefreshTokensInvalidUID(t *testing.T) { resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "expectedUserID" }` s := echoServer([]byte(resp), t) defer s.Close() we := "uid must be a non-empty string" if err := s.Client.RevokeRefreshTokens(context.Background(), ""); err == nil || err.Error() != we { t.Errorf("RevokeRefreshTokens(); err = %s; want err = %s", err.Error(), we) } } func TestInvalidSetCustomClaims(t *testing.T) { cases := []struct { cc map[string]interface{} want string }{ { map[string]interface{}{"a": strings.Repeat("a", 993)}, "serialized custom claims must not exceed 1000 characters", }, { map[string]interface{}{"a": func() {}}, "custom claims marshaling error: json: unsupported type: func()", }, } for _, res := range reservedClaims { s := struct { cc map[string]interface{} want string }{ map[string]interface{}{res: true}, fmt.Sprintf("claim %q is reserved and must not be set", res), } cases = append(cases, s) } client := &Client{ baseClient: &baseClient{}, } for _, tc := range cases { err := client.SetCustomUserClaims(context.Background(), "uid", tc.cc) if err == nil { t.Errorf("SetCustomUserClaims() = nil; want error: %s", tc.want) } if err.Error() != tc.want { t.Errorf("SetCustomUserClaims() = %q; want = %q", err.Error(), tc.want) } } } var setCustomUserClaimsCases = 
[]map[string]interface{}{ nil, {}, {"admin": true}, {"admin": true, "package": "gold"}, } func TestSetCustomUserClaims(t *testing.T) { resp := `{ "kind": "identitytoolkit#SetAccountInfoResponse", "localId": "uid" }` s := echoServer([]byte(resp), t) defer s.Close() wantPath := "/projects/mock-project-id/accounts:update" for _, tc := range setCustomUserClaimsCases { err := s.Client.SetCustomUserClaims(context.Background(), "uid", tc) if err != nil { t.Errorf("SetCustomUserClaims(%v) = %v; want nil", tc, err) } input := tc if input == nil { input = map[string]interface{}{} } b, err := json.Marshal(input) if err != nil { t.Fatal(err) } m := map[string]interface{}{ "localId": "uid", "customAttributes": string(b), } want, err := json.Marshal(m) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(s.Rbody, want) { t.Errorf("SetCustomUserClaims() = %v; want = %v", string(s.Rbody), string(want)) } hr := s.Req[len(s.Req)-1] if hr.RequestURI != wantPath { t.Errorf("RevokeRefreshTokens() URL = %q; want = %q", hr.RequestURI, wantPath) } } } func TestUserProvider(t *testing.T) { cases := []struct { provider *UserProvider want map[string]interface{} }{ { provider: &UserProvider{UID: "test", ProviderID: "google.com"}, want: map[string]interface{}{"rawId": "test", "providerId": "google.com"}, }, { provider: &UserProvider{ UID: "test", ProviderID: "google.com", DisplayName: "Test User", }, want: map[string]interface{}{ "rawId": "test", "providerId": "google.com", "displayName": "Test User", }, }, { provider: &UserProvider{ UID: "test", ProviderID: "google.com", DisplayName: "Test User", Email: "test@example.com", PhotoURL: "https://test.com/user.png", }, want: map[string]interface{}{ "rawId": "test", "providerId": "google.com", "displayName": "Test User", "email": "test@example.com", "photoUrl": "https://test.com/user.png", }, }, } for idx, tc := range cases { b, err := json.Marshal(tc.provider) if err != nil { t.Fatal(err) } var got map[string]interface{} if err := 
json.Unmarshal(b, &got); err != nil {
			t.Fatal(err)
		}
		if !reflect.DeepEqual(got, tc.want) {
			t.Errorf("[%d] UserProvider = %#v; want = %#v", idx, got, tc.want)
		}
	}
}

// TestUserToImport checks, for each UserToImport builder call, the map
// returned by validatedUserInfo.
func TestUserToImport(t *testing.T) {
	cases := []struct {
		user *UserToImport
		want map[string]interface{}
	}{
		{
			user: (&UserToImport{}).UID("test"),
			want: map[string]interface{}{
				"localId": "test",
			},
		},
		{
			user: (&UserToImport{}).UID("test").DisplayName("name"),
			want: map[string]interface{}{
				"localId":     "test",
				"displayName": "name",
			},
		},
		{
			user: (&UserToImport{}).UID("test").Email("test@example.com"),
			want: map[string]interface{}{
				"localId": "test",
				"email":   "test@example.com",
			},
		},
		{
			user: (&UserToImport{}).UID("test").PhotoURL("https://test.com/user.png"),
			want: map[string]interface{}{
				"localId":  "test",
				"photoUrl": "https://test.com/user.png",
			},
		},
		{
			user: (&UserToImport{}).UID("test").PhoneNumber("+1234567890"),
			want: map[string]interface{}{
				"localId":     "test",
				"phoneNumber": "+1234567890",
			},
		},
		{
			// Metadata timestamps are expected under createdAt / lastLoginAt.
			user: (&UserToImport{}).UID("test").Metadata(&UserMetadata{
				CreationTimestamp:  int64(100),
				LastLogInTimestamp: int64(150),
			}),
			want: map[string]interface{}{
				"localId":     "test",
				"createdAt":   int64(100),
				"lastLoginAt": int64(150),
			},
		},
		{
			// Only the timestamp that is set is expected in the output.
			user: (&UserToImport{}).UID("test").Metadata(&UserMetadata{
				CreationTimestamp: int64(100),
			}),
			want: map[string]interface{}{
				"localId":   "test",
				"createdAt": int64(100),
			},
		},
		{
			user: (&UserToImport{}).UID("test").Metadata(&UserMetadata{
				LastLogInTimestamp: int64(150),
			}),
			want: map[string]interface{}{
				"localId":     "test",
				"lastLoginAt": int64(150),
			},
		},
		{
			// Hash and salt are expected as unpadded URL-safe base64.
			user: (&UserToImport{}).UID("test").PasswordHash([]byte("password")),
			want: map[string]interface{}{
				"localId":      "test",
				"passwordHash": base64.RawURLEncoding.EncodeToString([]byte("password")),
			},
		},
		{
			user: (&UserToImport{}).UID("test").PasswordSalt([]byte("nacl")),
			want: map[string]interface{}{
				"localId": "test",
				"salt":    base64.RawURLEncoding.EncodeToString([]byte("nacl")),
			},
		},
		{
			user:
(&UserToImport{}).UID("test").CustomClaims(map[string]interface{}{"admin": true}), want: map[string]interface{}{ "localId": "test", "customAttributes": `{"admin":true}`, }, }, { user: (&UserToImport{}).UID("test").CustomClaims(map[string]interface{}{}), want: map[string]interface{}{ "localId": "test", }, }, { user: (&UserToImport{}).UID("test").ProviderData([]*UserProvider{ { ProviderID: "google.com", UID: "test", }, }), want: map[string]interface{}{ "localId": "test", "providerUserInfo": []*UserProvider{ { ProviderID: "google.com", UID: "test", }, }, }, }, { user: (&UserToImport{}).UID("test").EmailVerified(true), want: map[string]interface{}{ "localId": "test", "emailVerified": true, }, }, { user: (&UserToImport{}).UID("test").EmailVerified(false), want: map[string]interface{}{ "localId": "test", "emailVerified": false, }, }, { user: (&UserToImport{}).UID("test").Disabled(true), want: map[string]interface{}{ "localId": "test", "disabled": true, }, }, { user: (&UserToImport{}).UID("test").Disabled(false), want: map[string]interface{}{ "localId": "test", "disabled": false, }, }, } for idx, tc := range cases { got, err := tc.user.validatedUserInfo() if err != nil { t.Errorf("[%d] invalid user: %v", idx, err) } if !reflect.DeepEqual(got, tc.want) { t.Errorf("[%d] UserToImport = %#v; want = %#v", idx, got, tc.want) } } } func TestUserToImportError(t *testing.T) { cases := []struct { user *UserToImport want string }{ { &UserToImport{}, "no parameters are set on the user to import", }, { (&UserToImport{}).UID(""), "uid must be a non-empty string", }, { (&UserToImport{}).UID(strings.Repeat("a", 129)), "uid string must not be longer than 128 characters", }, { (&UserToImport{}).UID("test").Email("not-an-email"), `malformed email string: "not-an-email"`, }, { (&UserToImport{}).UID("test").PhoneNumber("not-a-phone"), "phone number must be a valid, E.164 compliant identifier", }, { (&UserToImport{}).UID("test").CustomClaims(map[string]interface{}{"key": strings.Repeat("a", 
1000)}), "serialized custom claims must not exceed 1000 characters", }, { (&UserToImport{}).UID("test").ProviderData([]*UserProvider{ { UID: "test", }, }), "user provider must specify a provider ID", }, { (&UserToImport{}).UID("test").ProviderData([]*UserProvider{ { ProviderID: "google.com", }, }), "user provider must specify a uid", }, } s := echoServer([]byte("{}"), t) defer s.Close() for idx, tc := range cases { _, err := s.Client.ImportUsers(context.Background(), []*UserToImport{tc.user}) if err == nil || err.Error() != tc.want { t.Errorf("[%d] UserToImport = %v; want = %q", idx, err, tc.want) } } } func TestInvalidImportUsers(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() result, err := s.Client.ImportUsers(context.Background(), nil) if result != nil || err == nil { t.Errorf("ImportUsers(nil) = (%v, %v); want = (nil, error)", result, err) } result, err = s.Client.ImportUsers(context.Background(), []*UserToImport{}) if result != nil || err == nil { t.Errorf("ImportUsers([]) = (%v, %v); want = (nil, error)", result, err) } var users []*UserToImport for i := 0; i < 1001; i++ { users = append(users, (&UserToImport{}).UID(fmt.Sprintf("user%d", i))) } result, err = s.Client.ImportUsers(context.Background(), users) if result != nil || err == nil { t.Errorf("ImportUsers(len > 1000) = (%v, %v); want = (nil, error)", result, err) } } func TestImportUsers(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() users := []*UserToImport{ (&UserToImport{}).UID("user1"), (&UserToImport{}).UID("user2"), } result, err := s.Client.ImportUsers(context.Background(), users) if err != nil { t.Fatal(err) } if result.SuccessCount != 2 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 2, FailureCount: 0}", result) } wantPath := "/projects/mock-project-id/accounts:batchCreate" if s.Req[0].RequestURI != wantPath { t.Errorf("ImportUsers() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestImportUsersError(t 
*testing.T) { resp := `{ "error": [ {"index": 0, "message": "Some error occurred in user1"}, {"index": 2, "message": "Another error occurred in user3"} ] }` s := echoServer([]byte(resp), t) defer s.Close() users := []*UserToImport{ (&UserToImport{}).UID("user1"), (&UserToImport{}).UID("user2"), (&UserToImport{}).UID("user3"), } result, err := s.Client.ImportUsers(context.Background(), users) if err != nil { t.Fatal(err) } if result.SuccessCount != 1 || result.FailureCount != 2 || len(result.Errors) != 2 { t.Fatalf("ImportUsers() = %#v; want = {SuccessCount: 1, FailureCount: 2}", result) } want := []ErrorInfo{ {Index: 0, Reason: "Some error occurred in user1"}, {Index: 2, Reason: "Another error occurred in user3"}, } for idx, we := range want { if *result.Errors[idx] != we { t.Errorf("[%d] Error = %#v; want = %#v", idx, result.Errors[idx], we) } } } type mockHash struct { key, saltSep string rounds, memoryCost int64 } func (h mockHash) Config() (internal.HashConfig, error) { return internal.HashConfig{ "hashAlgorithm": "MOCKHASH", "signerKey": h.key, "saltSeparator": h.saltSep, "rounds": h.rounds, "memoryCost": h.memoryCost, }, nil } func TestImportUsersWithHash(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() users := []*UserToImport{ (&UserToImport{}).UID("user1").PasswordHash([]byte("password")), (&UserToImport{}).UID("user2"), } result, err := s.Client.ImportUsers(context.Background(), users, WithHash(mockHash{ key: "key", saltSep: ",", rounds: 8, memoryCost: 14, })) if err != nil { t.Fatal(err) } if result.SuccessCount != 2 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 2, FailureCount: 0}", result) } var got map[string]interface{} if err := json.Unmarshal(s.Rbody, &got); err != nil { t.Fatal(err) } want := map[string]interface{}{ "hashAlgorithm": "MOCKHASH", "signerKey": "key", "saltSeparator": ",", "rounds": float64(8), "memoryCost": float64(14), } for k, v := range want { gv, ok := got[k] if !ok || gv != v 
{ t.Errorf("ImportUsers() request(%q) = %v; want = %v", k, gv, v) } } wantPath := "/projects/mock-project-id/accounts:batchCreate" if s.Req[0].RequestURI != wantPath { t.Errorf("ImportUsers() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestImportUsersMissingRequiredHash(t *testing.T) { s := echoServer([]byte("{}"), t) defer s.Close() users := []*UserToImport{ (&UserToImport{}).UID("user1").PasswordHash([]byte("password")), (&UserToImport{}).UID("user2"), } result, err := s.Client.ImportUsers(context.Background(), users) if result != nil || err == nil { t.Fatalf("ImportUsers() = (%v, %v); want = (nil, error)", result, err) } } func TestDeleteUser(t *testing.T) { resp := `{ "kind": "identitytoolkit#SignupNewUserResponse", "email": "", "localId": "expectedUserID" }` s := echoServer([]byte(resp), t) defer s.Close() if err := s.Client.DeleteUser(context.Background(), "uid"); err != nil { t.Errorf("DeleteUser() = %v; want = nil", err) } wantPath := "/projects/mock-project-id/accounts:delete" if s.Req[0].RequestURI != wantPath { t.Errorf("DeleteUser() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) } } func TestInvalidDeleteUser(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } if err := client.DeleteUser(context.Background(), ""); err == nil { t.Errorf("DeleteUser('') = nil; want error") } } func TestDeleteUsers(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } t.Run("should succeed given an empty list", func(t *testing.T) { result, err := client.DeleteUsers(context.Background(), []string{}) if err != nil { t.Fatalf("DeleteUsers([]) error %v; want = nil", err) } if result.SuccessCount != 0 { t.Errorf("DeleteUsers([]).SuccessCount = %d; want = 0", result.SuccessCount) } if result.FailureCount != 0 { t.Errorf("DeleteUsers([]).FailureCount = %d; want = 0", result.FailureCount) } if len(result.Errors) != 0 { t.Errorf("len(DeleteUsers([]).Errors) = %d; want = 0", len(result.Errors)) } }) t.Run("should be rejected when given 
more than 1000 identifiers", func(t *testing.T) { uids := []string{} for i := 0; i < 1001; i++ { uids = append(uids, fmt.Sprintf("id%d", i)) } _, err := client.DeleteUsers(context.Background(), uids) if err == nil { t.Fatalf("DeleteUsers([too_many_uids]) error nil; want not nil") } if err.Error() != "`uids` parameter must have <= 1000 entries" { t.Errorf( "DeleteUsers([too_many_uids]) returned an error of '%s'; "+ "expected '`uids` parameter must have <= 1000 entries'", err.Error()) } }) t.Run("should immediately fail given an invalid id", func(t *testing.T) { tooLongUID := "too long " + strings.Repeat(".", 128) _, err := client.DeleteUsers(context.Background(), []string{tooLongUID}) if err == nil { t.Fatalf("DeleteUsers([too_long_uid]) error nil; want not nil") } if err.Error() != "uid string must not be longer than 128 characters" { t.Errorf( "DeleteUsers([too_long_uid]) returned an error of '%s'; "+ "expected 'uid string must not be longer than 128 characters'", err.Error()) } }) t.Run("should index errors correctly in result", func(t *testing.T) { resp := `{ "errors": [{ "index": 0, "localId": "uid1", "message": "Error Message 1" }, { "index": 2, "localId": "uid3", "message": "Error Message 2" }] }` s := echoServer([]byte(resp), t) defer s.Close() result, err := s.Client.DeleteUsers(context.Background(), []string{"uid1", "uid2", "uid3", "uid4"}) if err != nil { t.Fatalf("DeleteUsers([...]) error %v; want = nil", err) } if result.SuccessCount != 2 { t.Errorf("DeleteUsers([...]).SuccessCount = %d; want 2", result.SuccessCount) } if result.FailureCount != 2 { t.Errorf("DeleteUsers([...]).FailureCount = %d; want 2", result.FailureCount) } if len(result.Errors) != 2 { t.Errorf("len(DeleteUsers([...]).Errors) = %d; want 2", len(result.Errors)) } else { if result.Errors[0].Index != 0 { t.Errorf("DeleteUsers([...]).Errors[0].Index = %d; want 0", result.Errors[0].Index) } if result.Errors[0].Reason != "Error Message 1" { t.Errorf("DeleteUsers([...]).Errors[0].Reason = 
%s; want Error Message 1", result.Errors[0].Reason) } if result.Errors[1].Index != 2 { t.Errorf("DeleteUsers([...]).Errors[1].Index = %d; want 2", result.Errors[1].Index) } if result.Errors[1].Reason != "Error Message 2" { t.Errorf("DeleteUsers([...]).Errors[1].Reason = %s; want Error Message 2", result.Errors[1].Reason) } } }) } func TestMakeExportedUser(t *testing.T) { queryResponse := &userQueryResponse{ UID: "testuser", Email: "testuser@example.com", PhoneNumber: "+1234567890", EmailVerified: true, DisplayName: "Test User", PasswordSalt: "salt", PhotoURL: "http://www.example.com/testuser/photo.png", PasswordHash: "passwordhash", ValidSinceSeconds: 1494364393, Disabled: false, CreationTimestamp: 1234567890000, LastLogInTimestamp: 1233211232000, CustomAttributes: `{"admin": true, "package": "gold"}`, TenantID: "testTenant", ProviderUserInfo: []*UserInfo{ { ProviderID: "password", DisplayName: "Test User", PhotoURL: "http://www.example.com/testuser/photo.png", Email: "testuser@example.com", UID: "testuid", }, { ProviderID: "phone", PhoneNumber: "+1234567890", UID: "testuid", }}, MFAInfo: []*multiFactorInfoResponse{ { PhoneInfo: "+1234567890", MFAEnrollmentID: "enrolledPhoneFactor", DisplayName: "My MFA Phone", EnrolledAt: "2021-03-03T13:06:20.542896Z", }, { TOTPInfo: &TOTPInfo{}, MFAEnrollmentID: "enrolledTOTPFactor", DisplayName: "My MFA TOTP", EnrolledAt: "2021-03-03T13:06:20.542896Z", }, }, } want := &ExportedUserRecord{ UserRecord: testUser, PasswordHash: "passwordhash", PasswordSalt: "salt", } exported, err := queryResponse.makeExportedUserRecord() if err != nil { t.Fatal(err) } if !reflect.DeepEqual(exported.UserRecord, want.UserRecord) { // zero in t.Errorf("makeExportedUser() = %#v; want: %#v \n(%#v)\n(%#v)", exported.UserRecord, want.UserRecord, exported.UserMetadata, want.UserMetadata) } if exported.PasswordHash != want.PasswordHash { t.Errorf("PasswordHash = %q; want = %q", exported.PasswordHash, want.PasswordHash) } if exported.PasswordSalt != 
want.PasswordSalt { t.Errorf("PasswordSalt = %q; want = %q", exported.PasswordSalt, want.PasswordSalt) } } func TestUnsupportedAuthFactor(t *testing.T) { queryResponse := &userQueryResponse{ UID: "uid1", MFAInfo: []*multiFactorInfoResponse{ { MFAEnrollmentID: "enrollementId", }, }, } exported, err := queryResponse.makeExportedUserRecord() if exported != nil || err == nil { t.Errorf("makeExportedUserRecord() = (%v, %v); want = (nil, error)", exported, err) } } func TestExportedUserRecordShouldClearRedacted(t *testing.T) { queryResponse := &userQueryResponse{ UID: "uid1", PasswordHash: base64.StdEncoding.EncodeToString([]byte("REDACTED")), } exported, err := queryResponse.makeExportedUserRecord() if err != nil { t.Fatal(err) } if exported.PasswordHash != "" { t.Errorf("PasswordHash = %q; want = ''", exported.PasswordHash) } } var createSessionCookieCases = []struct { expiresIn time.Duration want float64 }{ { expiresIn: 10 * time.Minute, want: 600.0, }, { expiresIn: 300500 * time.Millisecond, want: 300.0, }, } func TestSessionCookie(t *testing.T) { resp := `{ "sessionCookie": "expectedCookie" }` s := echoServer([]byte(resp), t) defer s.Close() for _, tc := range createSessionCookieCases { cookie, err := s.Client.SessionCookie(context.Background(), "idToken", tc.expiresIn) if cookie != "expectedCookie" || err != nil { t.Errorf("SessionCookie() = (%q, %v); want = (%q, nil)", cookie, err, "expectedCookie") } wantURL := "/projects/mock-project-id:createSessionCookie" if s.Req[0].URL.Path != wantURL { t.Errorf("SesionCookie() URL = %q; want = %q", s.Req[0].URL.Path, wantURL) } var got map[string]interface{} if err := json.Unmarshal(s.Rbody, &got); err != nil { t.Fatal(err) } want := map[string]interface{}{ "idToken": "idToken", "validDuration": tc.want, } if !reflect.DeepEqual(got, want) { t.Errorf("SessionCookie(%f) request =%#v; want = %#v", tc.want, got, want) } } } func TestSessionCookieError(t *testing.T) { resp := `{ "error": { "message": "PERMISSION_DENIED" } }` s 
:= echoServer([]byte(resp), t) defer s.Close() s.Status = http.StatusForbidden cookie, err := s.Client.SessionCookie(context.Background(), "idToken", 10*time.Minute) if cookie != "" || err == nil { t.Fatalf("SessionCookie() = (%q, %v); want = (%q, error)", cookie, err, "") } want := fmt.Sprintf("unexpected http response with status: 403\n%s", resp) if err.Error() != want || !errorutils.IsPermissionDenied(err) { t.Errorf("SessionCookie() error = %v; want = %q", err, want) } } func TestSessionCookieWithoutProjectID(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } _, err := client.SessionCookie(context.Background(), "idToken", 10*time.Minute) want := "project id not available" if err == nil || err.Error() != want { t.Errorf("SessionCookie() = %v; want = %q", err, want) } } func TestSessionCookieWithoutIDToken(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } if _, err := client.SessionCookie(context.Background(), "", 10*time.Minute); err == nil { t.Errorf("CreateSessionCookie('') = nil; want error") } } func TestSessionCookieShortExpiresIn(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } lessThanFiveMins := 5*time.Minute - time.Second if _, err := client.SessionCookie(context.Background(), "idToken", lessThanFiveMins); err == nil { t.Errorf("SessionCookie(< 5 mins) = nil; want error") } } func TestSessionCookieLongExpiresIn(t *testing.T) { client := &Client{ baseClient: &baseClient{}, } moreThanTwoWeeks := 14*24*time.Hour + time.Second if _, err := client.SessionCookie(context.Background(), "idToken", moreThanTwoWeeks); err == nil { t.Errorf("SessionCookie(> 14 days) = nil; want error") } } func TestHTTPError(t *testing.T) { s := echoServer([]byte(`{"error":"test"}`), t) defer s.Close() s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError u, err := s.Client.GetUser(context.Background(), "some uid") if u != nil || err == nil { t.Fatalf("GetUser() = (%v, %v); want = (nil, error)", u, err) 
} want := "unexpected http response with status: 500\n{\"error\":\"test\"}" if err.Error() != want || !errorutils.IsInternal(err) { t.Errorf("GetUser() = %v; want = %q", err, want) } } func TestHTTPErrorWithCode(t *testing.T) { errorCodes := map[string]struct { authCheck func(error) bool platformCheck func(error) bool want string }{ "CONFIGURATION_NOT_FOUND": { IsConfigurationNotFound, errorutils.IsNotFound, "no IdP configuration corresponding to the provided identifier", }, "DUPLICATE_EMAIL": { IsEmailAlreadyExists, errorutils.IsAlreadyExists, "user with the provided email already exists", }, "DUPLICATE_LOCAL_ID": { IsUIDAlreadyExists, errorutils.IsAlreadyExists, "user with the provided uid already exists", }, "EMAIL_EXISTS": { IsEmailAlreadyExists, errorutils.IsAlreadyExists, "user with the provided email already exists", }, "INVALID_DYNAMIC_LINK_DOMAIN": { IsInvalidDynamicLinkDomain, errorutils.IsInvalidArgument, "the provided dynamic link domain is not configured or authorized for the current project", }, "INVALID_HOSTING_LINK_DOMAIN": { IsInvalidHostingLinkDomain, errorutils.IsInvalidArgument, "the provided hosting link domain is not configured in Firebase Hosting or is not owned by the current project", }, "PHONE_NUMBER_EXISTS": { IsPhoneNumberAlreadyExists, errorutils.IsAlreadyExists, "user with the provided phone number already exists", }, "UNAUTHORIZED_DOMAIN": { IsUnauthorizedContinueURI, errorutils.IsInvalidArgument, "domain of the continue url is not whitelisted", }, "USER_NOT_FOUND": { IsUserNotFound, errorutils.IsNotFound, "no user record found for the given identifier", }, } s := echoServer(nil, t) defer s.Close() s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError for code, conf := range errorCodes { s.Resp = []byte(fmt.Sprintf(`{"error":{"message":"%s"}}`, code)) u, err := s.Client.GetUser(context.Background(), "some uid") if u != nil || err == nil { t.Fatalf("GetUser() = (%v, %v); want = (nil, error)", u, 
err) } if err.Error() != conf.want || !conf.authCheck(err) || !conf.platformCheck(err) { t.Errorf("GetUser() = %v; want = %q", err, conf.want) } } } func TestAuthErrorWithCodeAndDetails(t *testing.T) { resp := []byte(`{"error":{"message":"USER_NOT_FOUND: extra details"}}`) s := echoServer(resp, t) defer s.Close() s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError u, err := s.Client.GetUser(context.Background(), "some uid") if u != nil || err == nil { t.Fatalf("GetUser() = (%v, %v); want = (nil, error)", u, err) } want := "no user record found for the given identifier: extra details" if err.Error() != want || !IsUserNotFound(err) || !errorutils.IsNotFound(err) { t.Errorf("GetUser() = %v; want = %q", err, want) } } func TestAuthErrorWithUnknownCode(t *testing.T) { resp := `{"error":{"message":"UNKNOWN_CODE: extra details"}}` s := echoServer([]byte(resp), t) defer s.Close() s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError u, err := s.Client.GetUser(context.Background(), "some uid") if u != nil || err == nil { t.Fatalf("GetUser() = (%v, %v); want = (nil, error)", u, err) } want := fmt.Sprintf("unexpected http response with status: 500\n%s", resp) if err.Error() != want || !errorutils.IsInternal(err) { t.Errorf("GetUser() = %v; want = %q", err, want) } } func TestUnmappedHTTPError(t *testing.T) { errorCodes := map[string]struct { authCheck func(error) bool }{ "PROJECT_NOT_FOUND": { IsProjectNotFound, }, "INVALID_EMAIL": { IsInvalidEmail, }, "INSUFFICIENT_PERMISSION": { IsInsufficientPermission, }, "UNKNOWN": { IsUnknown, }, } s := echoServer(nil, t) defer s.Close() s.Client.baseClient.httpClient.RetryConfig = nil s.Status = http.StatusInternalServerError for code, conf := range errorCodes { s.Resp = []byte(fmt.Sprintf(`{"error":{"message":"%s"}}`, code)) u, err := s.Client.GetUser(context.Background(), "some uid") if u != nil || err == nil { t.Fatalf("GetUser() = (%v, %v); want = (nil, 
error)", u, err) } want := fmt.Sprintf("unexpected http response with status: 500\n%s", string(s.Resp)) if err.Error() != want || conf.authCheck(err) || !errorutils.IsInternal(err) { t.Errorf("GetUser() = %v; want = %q", err, want) } } } type mockAuthServer struct { Resp []byte Header map[string]string Status int Req []*http.Request Rbody []byte Srv *httptest.Server Client *Client } // echoServer takes either a []byte or a string filename, or an object. // // echoServer returns a server whose client will reply with depending on the input type: // - []byte: the []byte it got // - object: the marshalled object, in []byte form // - nil: "{}" empty json, in case we aren't interested in the returned value, just the marshalled request // // The marshalled request is available through s.rbody, s being the retuned server. // It also returns a closing functions that has to be defer closed. func echoServer(resp interface{}, t *testing.T) *mockAuthServer { var b []byte var err error switch v := resp.(type) { case nil: b = []byte("") case []byte: b = v default: if b, err = json.Marshal(resp); err != nil { t.Fatal("marshaling error") } } s := mockAuthServer{Resp: b} const testToken = "test.token" const testVersion = "test.version" handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer r.Body.Close() reqBody, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatal(err) } s.Rbody = bytes.TrimSpace(reqBody) s.Req = append(s.Req, r) gh := r.Header.Get("Authorization") wh := "Bearer " + testToken if gh != wh { t.Errorf("Authorization header = %q; want = %q", gh, wh) } gh = r.Header.Get("X-Client-Version") wh = "Go/Admin/" + testVersion if gh != wh { t.Errorf("X-Client-Version header = %q; want: %q", gh, wh) } gh = r.Header.Get("x-goog-api-client") wh = internal.GetMetricsHeader(testVersion) if gh != wh { t.Errorf("x-goog-api-client header = %q; want: %q", gh, wh) } for k, v := range s.Header { w.Header().Set(k, v) } if s.Status != 0 { w.WriteHeader(s.Status) 
} w.Header().Set("Content-Type", "application/json") w.Write(s.Resp) }) s.Srv = httptest.NewServer(handler) conf := &internal.AuthConfig{ Opts: optsWithTokenSource, ProjectID: "mock-project-id", Version: testVersion, } authClient, err := NewClient(context.Background(), conf) if err != nil { t.Fatal(err) } authClient.baseClient.userManagementEndpoint = s.Srv.URL authClient.baseClient.providerConfigEndpoint = s.Srv.URL authClient.TenantManager.endpoint = s.Srv.URL authClient.baseClient.projectMgtEndpoint = s.Srv.URL s.Client = authClient return &s } func (s *mockAuthServer) Close() { s.Srv.Close() } golang-google-firebase-go-4.18.0/db/000077500000000000000000000000001505612111400170675ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/db/auth_override_test.go000066400000000000000000000055771505612111400233330ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package db import ( "context" "testing" ) func TestAuthOverrideGet(t *testing.T) { mock := &mockServer{Resp: "data"} srv := mock.Start(aoClient) defer srv.Close() ref := aoClient.NewRef("peter") var got string if err := ref.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { t.Errorf("Ref(AuthOverride).Get() = %q; want = %q", got, "data") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"auth_variable_override": testAuthOverrides}, }) } func TestAuthOverrideSet(t *testing.T) { mock := &mockServer{} srv := mock.Start(aoClient) defer srv.Close() ref := aoClient.NewRef("peter") want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} if err := ref.Set(context.Background(), want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PUT", Body: serialize(want), Path: "/peter.json", Query: map[string]string{"auth_variable_override": testAuthOverrides, "print": "silent"}, }) } func TestAuthOverrideQuery(t *testing.T) { mock := &mockServer{Resp: "data"} srv := mock.Start(aoClient) defer srv.Close() ref := aoClient.NewRef("peter") var got string if err := ref.OrderByChild("foo").Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{ "auth_variable_override": testAuthOverrides, "orderBy": "\"foo\"", }, }) } func TestAuthOverrideRangeQuery(t *testing.T) { mock := &mockServer{Resp: "data"} srv := mock.Start(aoClient) defer srv.Close() ref := aoClient.NewRef("peter") var got string if err := ref.OrderByChild("foo").StartAt(1).EndAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "data" { t.Errorf("Ref(AuthOverride).OrderByChild() = %q; want = %q", got, "data") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: 
"/peter.json", Query: map[string]string{ "auth_variable_override": testAuthOverrides, "orderBy": "\"foo\"", "startAt": "1", "endAt": "10", }, }) } golang-google-firebase-go-4.18.0/db/db.go000066400000000000000000000150051505612111400200040ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package db contains functions for accessing the Firebase Realtime Database. package db import ( "context" "encoding/json" "errors" "fmt" "net/url" "os" "runtime" "strings" "firebase.google.com/go/v4/internal" "golang.org/x/oauth2" "google.golang.org/api/option" ) const userAgentFormat = "Firebase/HTTP/%s/%s/AdminGo" const invalidChars = "[].#$" const authVarOverride = "auth_variable_override" const emulatorDatabaseEnvVar = "FIREBASE_DATABASE_EMULATOR_HOST" const emulatorNamespaceParam = "ns" // errInvalidURL tells whether the given database url is invalid // It is invalid if it is malformed, or not of the format "host:port" var errInvalidURL = errors.New("invalid database url") var emulatorToken = &oauth2.Token{ AccessToken: "owner", } // Client is the interface for the Firebase Realtime Database service. 
type Client struct { hc *internal.HTTPClient dbURLConfig *dbURLConfig authOverride string } type dbURLConfig struct { // BaseURL can be either: // - a production url (https://foo-bar.firebaseio.com/) // - an emulator url (http://localhost:9000) BaseURL string // Namespace is used in for the emulator to specify the databaseName // To specify a namespace on your url, pass ns= (localhost:9000/?ns=foo-bar) Namespace string } // NewClient creates a new instance of the Firebase Database Client. // // This function can only be invoked from within the SDK. Client applications should access the // Database service through firebase.App. func NewClient(ctx context.Context, c *internal.DatabaseConfig) (*Client, error) { urlConfig, isEmulator, err := parseURLConfig(c.URL) if err != nil { return nil, err } var ao []byte if c.AuthOverride == nil || len(c.AuthOverride) > 0 { ao, err = json.Marshal(c.AuthOverride) if err != nil { return nil, err } } opts := append([]option.ClientOption{}, c.Opts...) if isEmulator { ts := oauth2.StaticTokenSource(emulatorToken) opts = append(opts, option.WithTokenSource(ts)) } ua := fmt.Sprintf(userAgentFormat, c.Version, runtime.Version()) opts = append(opts, option.WithUserAgent(ua)) hc, _, err := internal.NewHTTPClient(ctx, opts...) if err != nil { return nil, err } hc.CreateErrFn = handleRTDBError return &Client{ hc: hc, dbURLConfig: urlConfig, authOverride: string(ao), }, nil } // NewRef returns a new database reference representing the node at the specified path. 
func (c *Client) NewRef(path string) *Ref { segs := parsePath(path) key := "" if len(segs) > 0 { key = segs[len(segs)-1] } return &Ref{ Key: key, Path: "/" + strings.Join(segs, "/"), client: c, segs: segs, } } func (c *Client) sendAndUnmarshal( ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) { if strings.ContainsAny(req.URL, invalidChars) { return nil, fmt.Errorf("invalid path with illegal characters: %q", req.URL) } req.URL = fmt.Sprintf("%s%s.json", c.dbURLConfig.BaseURL, req.URL) if c.authOverride != "" { req.Opts = append(req.Opts, internal.WithQueryParam(authVarOverride, c.authOverride)) } if c.dbURLConfig.Namespace != "" { req.Opts = append(req.Opts, internal.WithQueryParam(emulatorNamespaceParam, c.dbURLConfig.Namespace)) } return c.hc.DoAndUnmarshal(ctx, req, v) } func parsePath(path string) []string { var segs []string for _, s := range strings.Split(path, "/") { if s != "" { segs = append(segs, s) } } return segs } func handleRTDBError(resp *internal.Response) error { err := internal.NewFirebaseError(resp) var p struct { Error string `json:"error"` } json.Unmarshal(resp.Body, &p) if p.Error != "" { err.String = fmt.Sprintf("http error status: %d; reason: %s", resp.Status, p.Error) } return err } // parseURLConfig returns the dbURLConfig for the database // dbURL may be either: // - a production url (https://foo-bar.firebaseio.com/) // - an emulator URL (localhost:9000/?ns=foo-bar) // // The following rules will apply for determining the output: // - If the url does not use an https scheme it will be assumed to be an emulator url and be used. // - else If the FIREBASE_DATABASE_EMULATOR_HOST environment variable is set it will be used. // - else the url will be assumed to be a production url and be used. 
func parseURLConfig(dbURL string) (*dbURLConfig, bool, error) { parsedURL, err := url.ParseRequestURI(dbURL) if err == nil && parsedURL.Scheme != "https" { cfg, err := parseEmulatorHost(dbURL, parsedURL) return cfg, true, err } environmentEmulatorURL := os.Getenv(emulatorDatabaseEnvVar) if environmentEmulatorURL != "" { parsedURL, err = url.ParseRequestURI(environmentEmulatorURL) if err != nil { return nil, false, fmt.Errorf("%s: %w", environmentEmulatorURL, errInvalidURL) } cfg, err := parseEmulatorHost(environmentEmulatorURL, parsedURL) return cfg, true, err } if err != nil { return nil, false, fmt.Errorf("%s: %w", dbURL, errInvalidURL) } return &dbURLConfig{ BaseURL: dbURL, Namespace: "", }, false, nil } func parseEmulatorHost(rawEmulatorHostURL string, parsedEmulatorHost *url.URL) (*dbURLConfig, error) { if strings.Contains(rawEmulatorHostURL, "//") { return nil, fmt.Errorf(`invalid %s: "%s". It must follow format "host:port": %w`, emulatorDatabaseEnvVar, rawEmulatorHostURL, errInvalidURL) } baseURL := strings.Replace(rawEmulatorHostURL, fmt.Sprintf("?%s", parsedEmulatorHost.RawQuery), "", -1) if parsedEmulatorHost.Scheme != "http" { baseURL = fmt.Sprintf("http://%s", baseURL) } namespace := parsedEmulatorHost.Query().Get(emulatorNamespaceParam) if namespace == "" { if strings.Contains(rawEmulatorHostURL, ".") { namespace = strings.Split(rawEmulatorHostURL, ".")[0] } if namespace == "" { return nil, fmt.Errorf(`invalid database URL: "%s". Database URL must be a valid URL to a Firebase Realtime Database instance (include ?ns= query param)`, parsedEmulatorHost) } } return &dbURLConfig{ BaseURL: baseURL, Namespace: namespace, }, nil } golang-google-firebase-go-4.18.0/db/db_test.go000066400000000000000000000305601505612111400210460ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package db import ( "context" "encoding/json" "fmt" "io/ioutil" "log" "net/http" "net/http/httptest" "net/url" "os" "reflect" "runtime" "testing" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" ) const ( testURL = "https://test-db.firebaseio.com" testEmulatorNamespace = "test-db" testEmulatorBaseURL = "http://localhost:9000" testEmulatorURL = "localhost:9000?ns=test-db" defaultMaxRetries = 1 ) var ( aoClient *Client client *Client testAuthOverrides string testref *Ref testUserAgent string testOpts = []option.ClientOption{ option.WithTokenSource(&internal.MockTokenSource{AccessToken: "mock-token"}), } ) func TestMain(m *testing.M) { var err error client, err = NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: testURL, Version: "1.2.3", AuthOverride: map[string]interface{}{}, }) if err != nil { log.Fatalln(err) } retryConfig := client.hc.RetryConfig retryConfig.MaxRetries = defaultMaxRetries retryConfig.ExpBackoffFactor = 0 ao := map[string]interface{}{"uid": "user1"} aoClient, err = NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: testURL, Version: "1.2.3", AuthOverride: ao, }) if err != nil { log.Fatalln(err) } b, err := json.Marshal(ao) if err != nil { log.Fatalln(err) } testAuthOverrides = string(b) testref = client.NewRef("peter") testUserAgent = fmt.Sprintf(userAgentFormat, "1.2.3", runtime.Version()) os.Exit(m.Run()) } func TestNewClient(t *testing.T) { cases := []*struct { Name string URL string EnvURL string ExpectedBaseURL string ExpectedNamespace string ExpectError bool 
}{ {Name: "production url", URL: testURL, ExpectedBaseURL: testURL, ExpectedNamespace: ""}, {Name: "emulator - success", URL: testEmulatorURL, ExpectedBaseURL: testEmulatorBaseURL, ExpectedNamespace: testEmulatorNamespace}, {Name: "emulator - missing namespace should error", URL: "localhost:9000", ExpectError: true}, {Name: "emulator - if url contains hostname it uses the primary domain", URL: "rtdb-go.emulator:9000", ExpectedBaseURL: "http://rtdb-go.emulator:9000", ExpectedNamespace: "rtdb-go"}, {Name: "emulator env - success", EnvURL: testEmulatorURL, ExpectedBaseURL: testEmulatorBaseURL, ExpectedNamespace: testEmulatorNamespace}, } for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { t.Setenv(emulatorDatabaseEnvVar, tc.EnvURL) fromEnv := os.Getenv(emulatorDatabaseEnvVar) fmt.Printf("%s", fromEnv) c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: tc.URL, AuthOverride: make(map[string]interface{}), }) if err != nil && tc.ExpectError { return } if err != nil && !tc.ExpectError { t.Fatal(err) } if err == nil && tc.ExpectError { t.Fatal("expected error") } if c.dbURLConfig.BaseURL != tc.ExpectedBaseURL { t.Errorf("NewClient().dbURLConfig.BaseURL = %q; want = %q", c.dbURLConfig.BaseURL, tc.ExpectedBaseURL) } if c.dbURLConfig.Namespace != tc.ExpectedNamespace { t.Errorf("NewClient(%v).Namespace = %q; want = %q", tc, c.dbURLConfig.Namespace, tc.ExpectedNamespace) } if c.hc == nil { t.Errorf("NewClient().hc = nil; want non-nil") } if c.authOverride != "" { t.Errorf("NewClient().ao = %q; want = %q", c.authOverride, "") } }) } } func TestNewClientAuthOverrides(t *testing.T) { cases := []*struct { Name string Params map[string]interface{} URL string ExpectedBaseURL string ExpectedNamespace string }{ {Name: "production - without override", Params: nil, URL: testURL, ExpectedBaseURL: testURL, ExpectedNamespace: ""}, {Name: "production - with override", Params: map[string]interface{}{"uid": "user1"}, URL: testURL, 
ExpectedBaseURL: testURL, ExpectedNamespace: ""}, {Name: "emulator - with no query params", Params: nil, URL: testEmulatorURL, ExpectedBaseURL: testEmulatorBaseURL, ExpectedNamespace: testEmulatorNamespace}, {Name: "emulator - with override", Params: map[string]interface{}{"uid": "user1"}, URL: testEmulatorURL, ExpectedBaseURL: testEmulatorBaseURL, ExpectedNamespace: testEmulatorNamespace}, } for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: tc.URL, AuthOverride: tc.Params, }) if err != nil { t.Fatal(err) } if c.dbURLConfig.BaseURL != tc.ExpectedBaseURL { t.Errorf("NewClient(%v).baseURL = %q; want = %q", tc, c.dbURLConfig.BaseURL, tc.ExpectedBaseURL) } if c.dbURLConfig.Namespace != tc.ExpectedNamespace { t.Errorf("NewClient(%v).Namespace = %q; want = %q", tc, c.dbURLConfig.Namespace, tc.ExpectedNamespace) } if c.hc == nil { t.Errorf("NewClient(%v).hc = nil; want non-nil", tc) } b, err := json.Marshal(tc.Params) if err != nil { t.Fatal(err) } if c.authOverride != string(b) { t.Errorf("NewClient(%v).ao = %q; want = %q", tc, c.authOverride, string(b)) } }) } } func TestValidURLS(t *testing.T) { cases := []string{ "https://test-db.firebaseio.com", "https://test-db.firebasedatabase.app", } for _, tc := range cases { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: tc, }) if err != nil { t.Fatal(err) } if c.dbURLConfig.BaseURL != tc { t.Errorf("NewClient(%v).url = %q; want = %q", tc, c.dbURLConfig.BaseURL, testURL) } } } func TestInvalidURL(t *testing.T) { cases := []string{ "", "foo", "http://db.firebaseio.com", "http://firebase.google.com", "http://localhost:9000", } for _, tc := range cases { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: tc, }) if c != nil || err == nil { t.Errorf("NewClient(%q) = (%v, %v); want = (nil, error)", tc, c, err) } } } func TestInvalidAuthOverride(t 
*testing.T) { c, err := NewClient(context.Background(), &internal.DatabaseConfig{ Opts: testOpts, URL: testURL, AuthOverride: map[string]interface{}{"uid": func() {}}, }) if c != nil || err == nil { t.Errorf("NewClient() = (%v, %v); want = (nil, error)", c, err) } } func TestNewRef(t *testing.T) { cases := []struct { Path string WantPath string WantKey string }{ {"", "/", ""}, {"/", "/", ""}, {"foo", "/foo", "foo"}, {"/foo", "/foo", "foo"}, {"foo/bar", "/foo/bar", "bar"}, {"/foo/bar", "/foo/bar", "bar"}, {"/foo/bar/", "/foo/bar", "bar"}, } for _, tc := range cases { r := client.NewRef(tc.Path) if r.client == nil { t.Errorf("NewRef(%q).client = nil; want = %v", tc.Path, r.client) } if r.Path != tc.WantPath { t.Errorf("NewRef(%q).Path = %q; want = %q", tc.Path, r.Path, tc.WantPath) } if r.Key != tc.WantKey { t.Errorf("NewRef(%q).Key = %q; want = %q", tc.Path, r.Key, tc.WantKey) } } } func TestParent(t *testing.T) { cases := []struct { Path string HasParent bool Want string }{ {"", false, ""}, {"/", false, ""}, {"foo", true, ""}, {"/foo", true, ""}, {"foo/bar", true, "foo"}, {"/foo/bar", true, "foo"}, {"/foo/bar/", true, "foo"}, } for _, tc := range cases { r := client.NewRef(tc.Path).Parent() if tc.HasParent { if r == nil { t.Fatalf("Parent(%q) = nil; want = Ref(%q)", tc.Path, tc.Want) } if r.client == nil { t.Errorf("Parent(%q).client = nil; want = %v", tc.Path, client) } if r.Key != tc.Want { t.Errorf("Parent(%q).Key = %q; want = %q", tc.Path, r.Key, tc.Want) } } else if r != nil { t.Fatalf("Parent(%q) = %v; want = nil", tc.Path, r) } } } func TestChild(t *testing.T) { r := client.NewRef("/test") cases := []struct { Path string Want string Parent string }{ {"", "/test", "/"}, {"foo", "/test/foo", "/test"}, {"/foo", "/test/foo", "/test"}, {"foo/", "/test/foo", "/test"}, {"/foo/", "/test/foo", "/test"}, {"//foo//", "/test/foo", "/test"}, {"foo/bar", "/test/foo/bar", "/test/foo"}, {"/foo/bar", "/test/foo/bar", "/test/foo"}, {"foo/bar/", "/test/foo/bar", "/test/foo"}, 
{"/foo/bar/", "/test/foo/bar", "/test/foo"}, {"//foo/bar", "/test/foo/bar", "/test/foo"}, {"foo//bar/", "/test/foo/bar", "/test/foo"}, {"foo/bar//", "/test/foo/bar", "/test/foo"}, } for _, tc := range cases { c := r.Child(tc.Path) if c.Path != tc.Want { t.Errorf("Child(%q) = %q; want = %q", tc.Path, c.Path, tc.Want) } if c.Parent().Path != tc.Parent { t.Errorf("Child(%q).Parent() = %q; want = %q", tc.Path, c.Parent().Path, tc.Parent) } } } func checkOnlyRequest(t *testing.T, got []*testReq, want *testReq) { checkAllRequests(t, got, []*testReq{want}) } func checkAllRequests(t *testing.T, got []*testReq, want []*testReq) { if len(got) != len(want) { t.Errorf("Request Count = %d; want = %d", len(got), len(want)) } else { for i, r := range got { checkRequest(t, r, want[i]) } } } func checkRequest(t *testing.T, got, want *testReq) { if h := got.Header.Get("Authorization"); h != "Bearer mock-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer mock-token") } if h := got.Header.Get("User-Agent"); h != testUserAgent { t.Errorf("User-Agent = %q; want = %q", h, testUserAgent) } if got.Method != want.Method { t.Errorf("Method = %q; want = %q", got.Method, want.Method) } if got.Path != want.Path { t.Errorf("Path = %q; want = %q", got.Path, want.Path) } if len(want.Query) != len(got.Query) { t.Errorf("QueryParam = %v; want = %v", got.Query, want.Query) } for k, v := range want.Query { if got.Query[k] != v { t.Errorf("QueryParam(%v) = %v; want = %v", k, got.Query[k], v) } } for k, v := range want.Header { if got.Header.Get(k) != v[0] { t.Errorf("Header(%q) = %q; want = %q", k, got.Header.Get(k), v[0]) } } if want.Body != nil { if h := got.Header.Get("Content-Type"); h != "application/json" { t.Errorf("User-Agent = %q; want = %q", h, "application/json") } var wi, gi interface{} if err := json.Unmarshal(want.Body, &wi); err != nil { t.Fatal(err) } if err := json.Unmarshal(got.Body, &gi); err != nil { t.Fatal(err) } if !reflect.DeepEqual(gi, wi) { t.Errorf("Body = %v; want 
= %v", gi, wi) } } else if len(got.Body) != 0 { t.Errorf("Body = %v; want empty", got.Body) } } type testReq struct { Method string Path string Header http.Header Body []byte Query map[string]string } func newTestReq(r *http.Request) (*testReq, error) { defer r.Body.Close() b, err := ioutil.ReadAll(r.Body) if err != nil { return nil, err } u, err := url.Parse(r.RequestURI) if err != nil { return nil, err } query := make(map[string]string) for k, v := range u.Query() { query[k] = v[0] } return &testReq{ Method: r.Method, Path: u.Path, Header: r.Header, Body: b, Query: query, }, nil } type mockServer struct { Resp interface{} Header map[string]string Status int Reqs []*testReq srv *httptest.Server } func (s *mockServer) Start(c *Client) *httptest.Server { if s.srv != nil { return s.srv } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr, _ := newTestReq(r) s.Reqs = append(s.Reqs, tr) for k, v := range s.Header { w.Header().Set(k, v) } print := r.URL.Query().Get("print") if s.Status != 0 { w.WriteHeader(s.Status) } else if print == "silent" { w.WriteHeader(http.StatusNoContent) return } b, _ := json.Marshal(s.Resp) w.Header().Set("Content-Type", "application/json") w.Write(b) }) s.srv = httptest.NewServer(handler) c.dbURLConfig.BaseURL = s.srv.URL return s.srv } type person struct { Name string `json:"name"` Age int32 `json:"age"` } func serialize(v interface{}) []byte { b, _ := json.Marshal(v) return b } golang-google-firebase-go-4.18.0/db/query.go000066400000000000000000000262141505612111400205700ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package db import ( "context" "encoding/json" "fmt" "net/http" "sort" "strconv" "strings" "firebase.google.com/go/v4/internal" ) // QueryNode represents a data node retrieved from an ordered query. type QueryNode interface { Key() string Unmarshal(v interface{}) error } // Query represents a complex query that can be executed on a Ref. // // Complex queries can consist of up to 2 components: a required ordering constraint, and an // optional filtering constraint. At the server, data is first sorted according to the given // ordering constraint (e.g. order by child). Then the filtering constraint (e.g. limit, range) is // applied on the sorted data to produce the final result. Despite the ordering constraint, the // final result is returned by the server as an unordered collection. Therefore the values read // from a Query instance are not ordered. type Query struct { client *Client path string order orderBy limFirst, limLast int start, end, equalTo interface{} } // StartAt returns a shallow copy of the Query with v set as a lower bound of a range query. // // The resulting Query will only return child nodes with a value greater than or equal to v. func (q *Query) StartAt(v interface{}) *Query { q2 := &Query{} *q2 = *q q2.start = v return q2 } // EndAt returns a shallow copy of the Query with v set as a upper bound of a range query. // // The resulting Query will only return child nodes with a value less than or equal to v. 
func (q *Query) EndAt(v interface{}) *Query { q2 := &Query{} *q2 = *q q2.end = v return q2 } // EqualTo returns a shallow copy of the Query with v set as an equals constraint. // // The resulting Query will only return child nodes whose values equal to v. func (q *Query) EqualTo(v interface{}) *Query { q2 := &Query{} *q2 = *q q2.equalTo = v return q2 } // LimitToFirst returns a shallow copy of the Query, which is anchored to the first n // elements of the window. func (q *Query) LimitToFirst(n int) *Query { q2 := &Query{} *q2 = *q q2.limFirst = n return q2 } // LimitToLast returns a shallow copy of the Query, which is anchored to the last n // elements of the window. func (q *Query) LimitToLast(n int) *Query { q2 := &Query{} *q2 = *q q2.limLast = n return q2 } // Get executes the Query and populates v with the results. // // Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and // therefore v has the same requirements as the json package. Specifically, it must be a pointer, // and must not be nil. // // Despite the ordering constraint of the Query, results are not stored in any particular order // in v. Use GetOrdered() to obtain ordered results. func (q *Query) Get(ctx context.Context, v interface{}) error { qp := make(map[string]string) if err := initQueryParams(q, qp); err != nil { return err } req := &internal.Request{ Method: http.MethodGet, URL: q.path, Opts: []internal.HTTPOption{internal.WithQueryParams(qp)}, } _, err := q.client.sendAndUnmarshal(ctx, req, v) return err } // GetOrdered executes the Query and returns the results as an ordered slice. 
func (q *Query) GetOrdered(ctx context.Context) ([]QueryNode, error) { var temp interface{} if err := q.Get(ctx, &temp); err != nil { return nil, err } if temp == nil { return nil, nil } sn := newSortableNodes(temp, q.order) sort.Sort(sn) result := make([]QueryNode, len(sn)) for i, v := range sn { result[i] = v } return result, nil } // OrderByChild returns a Query that orders data by child values before applying filters. // // Returned Query can be used to set additional parameters, and execute complex database queries // (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query // will inherit it. func (r *Ref) OrderByChild(child string) *Query { return newQuery(r, orderByChild(child)) } // OrderByKey returns a Query that orders data by key before applying filters. // // Returned Query can be used to set additional parameters, and execute complex database queries // (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query // will inherit it. func (r *Ref) OrderByKey() *Query { return newQuery(r, orderByProperty("$key")) } // OrderByValue returns a Query that orders data by value before applying filters. // // Returned Query can be used to set additional parameters, and execute complex database queries // (e.g. limit queries, range queries). If r has a context associated with it, the resulting Query // will inherit it. 
func (r *Ref) OrderByValue() *Query { return newQuery(r, orderByProperty("$value")) } func newQuery(r *Ref, ob orderBy) *Query { return &Query{ client: r.client, path: r.Path, order: ob, } } func initQueryParams(q *Query, qp map[string]string) error { ob, err := q.order.encode() if err != nil { return err } qp["orderBy"] = ob if q.limFirst > 0 && q.limLast > 0 { return fmt.Errorf("cannot set both limit parameter: first = %d, last = %d", q.limFirst, q.limLast) } else if q.limFirst < 0 { return fmt.Errorf("limit first cannot be negative: %d", q.limFirst) } else if q.limLast < 0 { return fmt.Errorf("limit last cannot be negative: %d", q.limLast) } if q.limFirst > 0 { qp["limitToFirst"] = strconv.Itoa(q.limFirst) } else if q.limLast > 0 { qp["limitToLast"] = strconv.Itoa(q.limLast) } if err := encodeFilter("startAt", q.start, qp); err != nil { return err } if err := encodeFilter("endAt", q.end, qp); err != nil { return err } return encodeFilter("equalTo", q.equalTo, qp) } func encodeFilter(key string, val interface{}, m map[string]string) error { if val == nil { return nil } b, err := json.Marshal(val) if err != nil { return err } m[key] = string(b) return nil } type orderBy interface { encode() (string, error) } type orderByChild string func (p orderByChild) encode() (string, error) { if p == "" { return "", fmt.Errorf("empty child path") } else if strings.ContainsAny(string(p), invalidChars) { return "", fmt.Errorf("invalid child path with illegal characters: %q", p) } segs := parsePath(string(p)) if len(segs) == 0 { return "", fmt.Errorf("invalid child path: %q", p) } b, err := json.Marshal(strings.Join(segs, "/")) if err != nil { return "", nil } return string(b), nil } type orderByProperty string func (p orderByProperty) encode() (string, error) { b, err := json.Marshal(p) if err != nil { return "", err } return string(b), nil } // Firebase type ordering: https://firebase.google.com/docs/database/rest/retrieve-data#section-rest-ordered-data const ( typeNull = 0 
typeBoolFalse = 1 typeBoolTrue = 2 typeNumeric = 3 typeString = 4 typeObject = 5 ) // comparableKey is a union type of numeric values and strings. type comparableKey struct { Num *float64 Str *string } func (k *comparableKey) Compare(o *comparableKey) int { if k.Str != nil && o.Str != nil { return strings.Compare(*k.Str, *o.Str) } else if k.Num != nil && o.Num != nil { if *k.Num < *o.Num { return -1 } else if *k.Num == *o.Num { return 0 } return 1 } else if k.Num != nil { // numeric keys appear before string keys return -1 } return 1 } func newComparableKey(v interface{}) *comparableKey { if s, ok := v.(string); ok { return &comparableKey{Str: &s} } // Numeric values could be int (in the case of array indices and type constants), or float64 (if // the value was received as json). if i, ok := v.(int); ok { f := float64(i) return &comparableKey{Num: &f} } f := v.(float64) return &comparableKey{Num: &f} } type queryNodeImpl struct { CompKey *comparableKey Value interface{} Index interface{} IndexType int } func (q *queryNodeImpl) Key() string { if q.CompKey.Str != nil { return *q.CompKey.Str } // Numeric keys in queryNodeImpl are always array indices, and can be safely converted into int. 
return strconv.Itoa(int(*q.CompKey.Num)) } func (q *queryNodeImpl) Unmarshal(v interface{}) error { b, err := json.Marshal(q.Value) if err != nil { return err } return json.Unmarshal(b, v) } func newQueryNode(key, val interface{}, order orderBy) *queryNodeImpl { var index interface{} if prop, ok := order.(orderByProperty); ok { if prop == "$value" { index = val } else { index = key } } else { path := order.(orderByChild) index = extractChildValue(val, string(path)) } return &queryNodeImpl{ CompKey: newComparableKey(key), Value: val, Index: index, IndexType: getIndexType(index), } } type sortableNodes []*queryNodeImpl func (s sortableNodes) Len() int { return len(s) } func (s sortableNodes) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s sortableNodes) Less(i, j int) bool { a, b := s[i], s[j] var aKey, bKey *comparableKey if a.IndexType == b.IndexType { // If the indices have the same type and are comparable (i.e. numeric or string), compare // them directly. Otherwise, compare the keys. if (a.IndexType == typeNumeric || a.IndexType == typeString) && a.Index != b.Index { aKey, bKey = newComparableKey(a.Index), newComparableKey(b.Index) } else { aKey, bKey = a.CompKey, b.CompKey } } else { // If the indices are of different types, use the type ordering of Firebase. aKey, bKey = newComparableKey(a.IndexType), newComparableKey(b.IndexType) } return aKey.Compare(bKey) < 0 } func newSortableNodes(values interface{}, order orderBy) sortableNodes { var entries sortableNodes if m, ok := values.(map[string]interface{}); ok { for key, val := range m { entries = append(entries, newQueryNode(key, val, order)) } } else if l, ok := values.([]interface{}); ok { for key, val := range l { entries = append(entries, newQueryNode(key, val, order)) } } else { entries = append(entries, newQueryNode(0, values, order)) } return entries } // extractChildValue retrieves the value at path from val. 
// // If the given path does not exist in val, or val does not support child path traversal, // extractChildValue returns nil. func extractChildValue(val interface{}, path string) interface{} { segments := parsePath(path) curr := val for _, s := range segments { if curr == nil { return nil } currMap, ok := curr.(map[string]interface{}) if !ok { return nil } if curr, ok = currMap[s]; !ok { return nil } } return curr } func getIndexType(index interface{}) int { if index == nil { return typeNull } else if b, ok := index.(bool); ok { if b { return typeBoolTrue } return typeBoolFalse } else if _, ok := index.(float64); ok { return typeNumeric } else if _, ok := index.(string); ok { return typeString } return typeObject } golang-google-firebase-go-4.18.0/db/query_test.go000066400000000000000000000524261505612111400216330ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package db import ( "context" "fmt" "net/http" "reflect" "testing" "firebase.google.com/go/v4/errorutils" ) var sortableKeysResp = map[string]interface{}{ "bob": person{Name: "bob", Age: 20}, "alice": person{Name: "alice", Age: 30}, "charlie": person{Name: "charlie", Age: 15}, "dave": person{Name: "dave", Age: 25}, "ernie": person{Name: "ernie"}, } var sortableValuesResp = []struct { resp map[string]interface{} want []interface{} wantKeys []string }{ { resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 3}, want: []interface{}{1.0, 2.0, 3.0}, wantKeys: []string{"k1", "k2", "k3"}, }, { resp: map[string]interface{}{"k1": 3, "k2": 2, "k3": 1}, want: []interface{}{1.0, 2.0, 3.0}, wantKeys: []string{"k3", "k2", "k1"}, }, { resp: map[string]interface{}{"k1": 3, "k2": 1, "k3": 2}, want: []interface{}{1.0, 2.0, 3.0}, wantKeys: []string{"k2", "k3", "k1"}, }, { resp: map[string]interface{}{"k1": 1, "k2": 2, "k3": 1}, want: []interface{}{1.0, 1.0, 2.0}, wantKeys: []string{"k1", "k3", "k2"}, }, { resp: map[string]interface{}{"k1": 1, "k2": 1, "k3": 2}, want: []interface{}{1.0, 1.0, 2.0}, wantKeys: []string{"k1", "k2", "k3"}, }, { resp: map[string]interface{}{"k1": 2, "k2": 1, "k3": 1}, want: []interface{}{1.0, 1.0, 2.0}, wantKeys: []string{"k2", "k3", "k1"}, }, { resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": "baz"}, want: []interface{}{"bar", "baz", "foo"}, wantKeys: []string{"k2", "k3", "k1"}, }, { resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": 10}, want: []interface{}{10.0, "bar", "foo"}, wantKeys: []string{"k3", "k2", "k1"}, }, { resp: map[string]interface{}{"k1": "foo", "k2": "bar", "k3": nil}, want: []interface{}{nil, "bar", "foo"}, wantKeys: []string{"k3", "k2", "k1"}, }, { resp: map[string]interface{}{"k1": 5, "k2": "bar", "k3": nil}, want: []interface{}{nil, 5.0, "bar"}, wantKeys: []string{"k3", "k1", "k2"}, }, { resp: map[string]interface{}{ "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, "k6": map[string]interface{}{"k1": 
true}, }, want: []interface{}{false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}}, wantKeys: []string{"k5", "k1", "k2", "k3", "k4", "k6"}, }, { resp: map[string]interface{}{ "k1": true, "k2": 0, "k3": "foo", "k4": "foo", "k5": false, "k6": map[string]interface{}{"k1": true}, "k7": nil, "k8": map[string]interface{}{"k0": true}, }, want: []interface{}{ nil, false, true, 0.0, "foo", "foo", map[string]interface{}{"k1": true}, map[string]interface{}{"k0": true}, }, wantKeys: []string{"k7", "k5", "k1", "k2", "k3", "k4", "k6", "k8"}, }, } func TestChildQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() cases := []string{ "messages", "messages/", "/messages", } var reqs []*testReq for _, tc := range cases { var got map[string]interface{} if err := testref.OrderByChild(tc).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("OrderByChild(%q) = %v; want = %v", tc, got, want) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"messages\""}, }) } checkAllRequests(t, mock.Reqs, reqs) } func TestNestedChildQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByChild("messages/ratings").Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("OrderByChild(%q) = %v; want = %v", "messages/ratings", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"messages/ratings\""}, }) } func TestChildQueryWithParams(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() q := 
testref.OrderByChild("messages").StartAt("m4").EndAt("m50").LimitToFirst(10) var got map[string]interface{} if err := q.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("OrderByChild() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{ "orderBy": "\"messages\"", "startAt": "\"m4\"", "endAt": "\"m50\"", "limitToFirst": "10", }, }) } func TestInvalidOrderByChild(t *testing.T) { mock := &mockServer{Resp: "test"} srv := mock.Start(client) defer srv.Close() r := client.NewRef("/") cases := []string{ "", "/", "foo$", "foo.", "foo#", "foo]", "foo[", "$key", "$value", "$priority", } for _, tc := range cases { var got string if err := r.OrderByChild(tc).Get(context.Background(), &got); got != "" || err == nil { t.Errorf("OrderByChild(%q) = (%q, %v); want = (%q, error)", tc, got, err, "") } } if len(mock.Reqs) != 0 { t.Errorf("OrderByChild() = %v; want = empty", mock.Reqs) } } func TestKeyQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByKey().Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("OrderByKey() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"$key\""}, }) } func TestValueQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByValue().Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("OrderByValue() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", 
Query: map[string]string{"orderBy": "\"$value\""}, }) } func TestLimitFirstQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByChild("messages").LimitToFirst(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("LimitToFirst() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"limitToFirst": "10", "orderBy": "\"messages\""}, }) } func TestLimitLastQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByChild("messages").LimitToLast(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"limitToLast": "10", "orderBy": "\"messages\""}, }) } func TestInvalidLimitQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() q := testref.OrderByChild("messages") cases := []struct { name string q *Query }{ {"BothLimits", q.LimitToFirst(10).LimitToLast(10)}, {"NegativeFirst", q.LimitToFirst(-10)}, {"NegativeLast", q.LimitToLast(-10)}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { var got map[string]interface{} if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) } if len(mock.Reqs) != 0 { t.Errorf("OrderByChild(%q): %v; want: empty", tc.name, mock.Reqs) } }) } } func TestStartAtQuery(t *testing.T) { want := 
map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByChild("messages").StartAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("StartAt() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"startAt": "10", "orderBy": "\"messages\""}, }) } func TestEndAtQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByChild("messages").EndAt(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("EndAt() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"endAt": "10", "orderBy": "\"messages\""}, }) } func TestEqualToQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got map[string]interface{} if err := testref.OrderByChild("messages").EqualTo(10).Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("EqualTo() = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"equalTo": "10", "orderBy": "\"messages\""}, }) } func TestInvalidFilterQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() q := testref.OrderByChild("messages") cases := []struct { name string q *Query }{ {"InvalidStartAt", q.StartAt(func() {})}, {"InvalidEndAt", q.EndAt(func() {})}, {"InvalidEqualTo", q.EqualTo(func() 
{})}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { var got map[string]interface{} if err := tc.q.Get(context.Background(), &got); got != nil || err == nil { t.Errorf("OrderByChild(%q) = (%v, %v); want = (nil, error)", tc.name, got, err) } if len(mock.Reqs) != 0 { t.Errorf("OrdderByChild(%q) = %v; want = empty", tc.name, mock.Reqs) } }) } } func TestAllParamsQuery(t *testing.T) { want := map[string]interface{}{"m1": "Hello", "m2": "Bye"} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() q := testref.OrderByChild("messages").LimitToFirst(100).StartAt("bar").EndAt("foo") var got map[string]interface{} if err := q.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("OrderByChild(AllParams) = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{ "limitToFirst": "100", "startAt": "\"bar\"", "endAt": "\"foo\"", "orderBy": "\"messages\"", }, }) } func TestChildQueryGetOrdered(t *testing.T) { mock := &mockServer{Resp: sortableKeysResp} srv := mock.Start(client) defer srv.Close() cases := []struct { child string want []string }{ {"name", []string{"alice", "bob", "charlie", "dave", "ernie"}}, {"age", []string{"ernie", "charlie", "bob", "dave", "alice"}}, {"nonexisting", []string{"alice", "bob", "charlie", "dave", "ernie"}}, } var reqs []*testReq for idx, tc := range cases { result, err := testref.OrderByChild(tc.child).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": fmt.Sprintf("%q", tc.child)}, }) var gotKeys, gotVals []string for _, r := range result { var p person if err := r.Unmarshal(&p); err != nil { t.Fatal(err) } gotKeys = append(gotKeys, r.Key()) gotVals = append(gotVals, p.Name) } if !reflect.DeepEqual(tc.want, gotKeys) { t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", 
idx, tc.child, gotKeys, tc.want) } if !reflect.DeepEqual(tc.want, gotVals) { t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, tc.child, gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) } func TestImmediateChildQueryGetOrdered(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() type parsedMap struct { Child interface{} `json:"child"` } var reqs []*testReq for idx, tc := range sortableValuesResp { resp := map[string]interface{}{} for k, v := range tc.resp { resp[k] = map[string]interface{}{"child": v} } mock.Resp = resp result, err := testref.OrderByChild("child").GetOrdered(context.Background()) if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"child\""}, }) var gotKeys []string var gotVals []interface{} for _, r := range result { var p parsedMap if err := r.Unmarshal(&p); err != nil { t.Fatal(err) } gotKeys = append(gotKeys, r.Key()) gotVals = append(gotVals, p.Child) } if !reflect.DeepEqual(tc.wantKeys, gotKeys) { t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotKeys, tc.wantKeys) } if !reflect.DeepEqual(tc.want, gotVals) { t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child", gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) } func TestNestedChildQueryGetOrdered(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() type grandChild struct { GrandChild interface{} `json:"grandchild"` } type parsedMap struct { Child grandChild `json:"child"` } var reqs []*testReq for idx, tc := range sortableValuesResp { resp := map[string]interface{}{} for k, v := range tc.resp { resp[k] = map[string]interface{}{"child": map[string]interface{}{"grandchild": v}} } mock.Resp = resp q := testref.OrderByChild("child/grandchild") result, err := q.GetOrdered(context.Background()) if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", 
Query: map[string]string{"orderBy": "\"child/grandchild\""}, }) var gotKeys []string var gotVals []interface{} for _, r := range result { var p parsedMap if err := r.Unmarshal(&p); err != nil { t.Fatal(err) } gotKeys = append(gotKeys, r.Key()) gotVals = append(gotVals, p.Child.GrandChild) } if !reflect.DeepEqual(tc.wantKeys, gotKeys) { t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotKeys, tc.wantKeys) } if !reflect.DeepEqual(tc.want, gotVals) { t.Errorf("[%d] GetOrdered(child: %q) = %v; want = %v", idx, "child/grandchild", gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) } func TestKeyQueryGetOrdered(t *testing.T) { mock := &mockServer{Resp: sortableKeysResp} srv := mock.Start(client) defer srv.Close() result, err := testref.OrderByKey().GetOrdered(context.Background()) if err != nil { t.Fatal(err) } req := &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"$key\""}, } var gotKeys, gotVals []string for _, r := range result { var p person if err := r.Unmarshal(&p); err != nil { t.Fatal(err) } gotKeys = append(gotKeys, r.Key()) gotVals = append(gotVals, p.Name) } want := []string{"alice", "bob", "charlie", "dave", "ernie"} if !reflect.DeepEqual(want, gotKeys) { t.Errorf("GetOrdered(key) = %v; want = %v", gotKeys, want) } if !reflect.DeepEqual(want, gotVals) { t.Errorf("GetOrdered(key) = %v; want = %v", gotVals, want) } checkOnlyRequest(t, mock.Reqs, req) } func TestValueQueryGetOrdered(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() var reqs []*testReq for idx, tc := range sortableValuesResp { mock.Resp = tc.resp result, err := testref.OrderByValue().GetOrdered(context.Background()) if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"$value\""}, }) var gotKeys []string var gotVals []interface{} for _, r := range result { var v interface{} if err := r.Unmarshal(&v); err != 
nil { t.Fatal(err) } gotKeys = append(gotKeys, r.Key()) gotVals = append(gotVals, v) } if !reflect.DeepEqual(tc.wantKeys, gotKeys) { t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotKeys, tc.wantKeys) } if !reflect.DeepEqual(tc.want, gotVals) { t.Errorf("[%d] GetOrdered(value) = %v; want = %v", idx, gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) } func TestValueQueryGetOrderedWithList(t *testing.T) { cases := []struct { resp []interface{} want []interface{} wantKeys []string }{ { resp: []interface{}{1, 2, 3}, want: []interface{}{1.0, 2.0, 3.0}, wantKeys: []string{"0", "1", "2"}, }, { resp: []interface{}{3, 2, 1}, want: []interface{}{1.0, 2.0, 3.0}, wantKeys: []string{"2", "1", "0"}, }, { resp: []interface{}{1, 3, 2}, want: []interface{}{1.0, 2.0, 3.0}, wantKeys: []string{"0", "2", "1"}, }, { resp: []interface{}{1, 3, 3}, want: []interface{}{1.0, 3.0, 3.0}, wantKeys: []string{"0", "1", "2"}, }, { resp: []interface{}{1, 2, 1}, want: []interface{}{1.0, 1.0, 2.0}, wantKeys: []string{"0", "2", "1"}, }, { resp: []interface{}{"foo", "bar", "baz"}, want: []interface{}{"bar", "baz", "foo"}, wantKeys: []string{"1", "2", "0"}, }, { resp: []interface{}{"foo", 1, false, nil, 0, true}, want: []interface{}{nil, false, true, 0.0, 1.0, "foo"}, wantKeys: []string{"3", "2", "5", "4", "1", "0"}, }, } mock := &mockServer{} srv := mock.Start(client) defer srv.Close() var reqs []*testReq for _, tc := range cases { mock.Resp = tc.resp result, err := testref.OrderByValue().GetOrdered(context.Background()) if err != nil { t.Fatal(err) } reqs = append(reqs, &testReq{ Method: "GET", Path: "/peter.json", Query: map[string]string{"orderBy": "\"$value\""}, }) var gotKeys []string var gotVals []interface{} for _, r := range result { var v interface{} if err := r.Unmarshal(&v); err != nil { t.Fatal(err) } gotKeys = append(gotKeys, r.Key()) gotVals = append(gotVals, v) } if !reflect.DeepEqual(tc.wantKeys, gotKeys) { t.Errorf("GetOrdered(value) = %v; want = %v", gotKeys, 
tc.wantKeys) } if !reflect.DeepEqual(tc.want, gotVals) { t.Errorf("GetOrdered(value) = %v; want = %v", gotVals, tc.want) } } checkAllRequests(t, mock.Reqs, reqs) } func TestGetOrderedWithNilResult(t *testing.T) { mock := &mockServer{Resp: nil} srv := mock.Start(client) defer srv.Close() result, err := testref.OrderByChild("child").GetOrdered(context.Background()) if err != nil { t.Fatal(err) } if result != nil { t.Errorf("GetOrdered(value) = %v; want = nil", result) } } func TestGetOrderedWithLeafNode(t *testing.T) { mock := &mockServer{Resp: "foo"} srv := mock.Start(client) defer srv.Close() result, err := testref.OrderByChild("child").GetOrdered(context.Background()) if err != nil { t.Fatal(err) } if len(result) != 1 { t.Fatalf("GetOrdered(chid) = %d; want = 1", len(result)) } if result[0].Key() != "0" { t.Errorf("GetOrdered(value).Key() = %v; want = %q", result[0].Key(), 0) } var v interface{} if err := result[0].Unmarshal(&v); err != nil { t.Fatal(err) } if v != "foo" { t.Errorf("GetOrdered(value) = %v; want = %v", v, "foo") } } func TestQueryHttpError(t *testing.T) { mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} srv := mock.Start(client) defer srv.Close() want := "http error status: 500; reason: test error" result, err := testref.OrderByChild("child").GetOrdered(context.Background()) if result != nil || err == nil || err.Error() != want { t.Fatalf("GetOrdered() = (%v, %v); want = (nil, %v)", result, err, want) } if !errorutils.IsInternal(err) { t.Errorf("IsInternal(err) = false; want = true") } resp := errorutils.HTTPResponse(err) if resp == nil || resp.StatusCode != http.StatusInternalServerError { t.Errorf("HTTPResponse(err) = %v; want = {StatusCode: %d}", resp, http.StatusInternalServerError) } } golang-google-firebase-go-4.18.0/db/ref.go000066400000000000000000000221531505612111400201750ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package db import ( "context" "encoding/json" "fmt" "net/http" "strings" "firebase.google.com/go/v4/internal" ) // txnRetires is the maximum number of times a transaction is retried before giving up. Transaction // retries are triggered by concurrent conflicting updates to the same database location. const txnRetries = 25 // Ref represents a node in the Firebase Realtime Database. type Ref struct { Key string Path string segs []string client *Client } // TransactionNode represents the value of a node within the scope of a transaction. type TransactionNode interface { Unmarshal(v interface{}) error } type transactionNodeImpl struct { Raw []byte } func (t *transactionNodeImpl) Unmarshal(v interface{}) error { return json.Unmarshal(t.Raw, v) } // Parent returns a reference to the parent of the current node. // // If the current reference points to the root of the database, Parent returns nil. func (r *Ref) Parent() *Ref { l := len(r.segs) if l > 0 { path := strings.Join(r.segs[:l-1], "/") return r.client.NewRef(path) } return nil } // Child returns a reference to the specified child node. func (r *Ref) Child(path string) *Ref { fp := fmt.Sprintf("%s/%s", r.Path, path) return r.client.NewRef(fp) } // Get retrieves the value at the current database location, and stores it in the value pointed to // by v. 
// // Data deserialization is performed using https://golang.org/pkg/encoding/json/#Unmarshal, and // therefore v has the same requirements as the json package. Specifically, it must be a pointer, // and must not be nil. func (r *Ref) Get(ctx context.Context, v interface{}) error { req := &internal.Request{ Method: http.MethodGet, } _, err := r.sendAndUnmarshal(ctx, req, v) return err } // GetWithETag retrieves the value at the current database location, along with its ETag. func (r *Ref) GetWithETag(ctx context.Context, v interface{}) (string, error) { req := &internal.Request{ Method: http.MethodGet, Opts: []internal.HTTPOption{ internal.WithHeader("X-Firebase-ETag", "true"), }, } resp, err := r.sendAndUnmarshal(ctx, req, v) if err != nil { return "", err } return resp.Header.Get("Etag"), nil } // GetShallow performs a shallow read on the current database location. // // Shallow reads do not retrieve the child nodes of the current reference. func (r *Ref) GetShallow(ctx context.Context, v interface{}) error { req := &internal.Request{ Method: http.MethodGet, Opts: []internal.HTTPOption{ internal.WithQueryParam("shallow", "true"), }, } _, err := r.sendAndUnmarshal(ctx, req, v) return err } // GetIfChanged retrieves the value and ETag of the current database location only if the specified // ETag does not match. // // If the specified ETag does not match, returns true along with the latest ETag of the database // location. The value of the database location will be stored in v just like a regular Get() call. // If the etag matches, returns false along with the same ETag passed into the function. No data // will be stored in v in this case. 
func (r *Ref) GetIfChanged(ctx context.Context, etag string, v interface{}) (bool, string, error) { req := &internal.Request{ Method: http.MethodGet, Opts: []internal.HTTPOption{ internal.WithHeader("If-None-Match", etag), }, SuccessFn: successOrNotModified, } resp, err := r.sendAndUnmarshal(ctx, req, nil) if err != nil { return false, "", err } if resp.Status == http.StatusNotModified { return false, etag, nil } if err := json.Unmarshal(resp.Body, v); err != nil { return false, "", err } return true, resp.Header.Get("ETag"), nil } // Set stores the value v in the current database node. // // Set uses https://golang.org/pkg/encoding/json/#Marshal to serialize values into JSON. Therefore // v has the same requirements as the json package. Values like functions and channels cannot be // saved into Realtime Database. func (r *Ref) Set(ctx context.Context, v interface{}) error { req := &internal.Request{ Method: http.MethodPut, Body: internal.NewJSONEntity(v), Opts: []internal.HTTPOption{ internal.WithQueryParam("print", "silent"), }, } _, err := r.sendAndUnmarshal(ctx, req, nil) return err } // SetIfUnchanged conditionally sets the data at this location to the given value. // // Sets the data at this location to v only if the specified ETag matches. Returns true if the // value is written. Returns false if no changes are made to the database. func (r *Ref) SetIfUnchanged(ctx context.Context, etag string, v interface{}) (bool, error) { req := &internal.Request{ Method: http.MethodPut, Body: internal.NewJSONEntity(v), Opts: []internal.HTTPOption{ internal.WithHeader("If-Match", etag), }, SuccessFn: successOrPreconditionFailed, } resp, err := r.sendAndUnmarshal(ctx, req, nil) if err != nil { return false, err } if resp.Status == http.StatusPreconditionFailed { return false, nil } return true, nil } // Push creates a new child node at the current location, and returns a reference to it. // // If v is not nil, it will be set as the initial value of the new child node. 
If v is nil, the // new child node will be created with empty string as the value. func (r *Ref) Push(ctx context.Context, v interface{}) (*Ref, error) { if v == nil { v = "" } req := &internal.Request{ Method: http.MethodPost, Body: internal.NewJSONEntity(v), } var d struct { Name string `json:"name"` } if _, err := r.sendAndUnmarshal(ctx, req, &d); err != nil { return nil, err } return r.Child(d.Name), nil } // Update modifies the specified child keys of the current location to the provided values. func (r *Ref) Update(ctx context.Context, v map[string]interface{}) error { if len(v) == 0 { return fmt.Errorf("value argument must be a non-empty map") } req := &internal.Request{ Method: http.MethodPatch, Body: internal.NewJSONEntity(v), Opts: []internal.HTTPOption{ internal.WithQueryParam("print", "silent"), }, } _, err := r.sendAndUnmarshal(ctx, req, nil) return err } // UpdateFn represents a function type that can be passed into Transaction(). type UpdateFn func(TransactionNode) (interface{}, error) // Transaction atomically modifies the data at this location. // // Unlike a normal Set(), which just overwrites the data regardless of its previous state, // Transaction() is used to modify the existing value to a new value, ensuring there are no // conflicts with other clients simultaneously writing to the same location. // // This is accomplished by passing an update function which is used to transform the current value // of this reference into a new value. If another client writes to this location before the new // value is successfully saved, the update function is called again with the new current value, and // the write will be retried. In case of repeated failures, this method will retry the transaction up // to 25 times before giving up and returning an error. // // The update function may also force an early abort by returning an error instead of returning a // value. 
func (r *Ref) Transaction(ctx context.Context, fn UpdateFn) error { req := &internal.Request{ Method: http.MethodGet, Opts: []internal.HTTPOption{ internal.WithHeader("X-Firebase-ETag", "true"), }, } resp, err := r.sendAndUnmarshal(ctx, req, nil) if err != nil { return err } etag := resp.Header.Get("Etag") for i := 0; i < txnRetries; i++ { new, err := fn(&transactionNodeImpl{resp.Body}) if err != nil { return err } req := &internal.Request{ Method: http.MethodPut, Body: internal.NewJSONEntity(new), Opts: []internal.HTTPOption{ internal.WithHeader("If-Match", etag), }, SuccessFn: successOrPreconditionFailed, } resp, err = r.sendAndUnmarshal(ctx, req, nil) if err != nil { return err } if resp.Status == http.StatusOK { return nil } etag = resp.Header.Get("ETag") } return fmt.Errorf("transaction aborted after failed retries") } // Delete removes this node from the database. func (r *Ref) Delete(ctx context.Context) error { req := &internal.Request{ Method: http.MethodDelete, } _, err := r.sendAndUnmarshal(ctx, req, nil) return err } func (r *Ref) sendAndUnmarshal( ctx context.Context, req *internal.Request, v interface{}) (*internal.Response, error) { req.URL = r.Path return r.client.sendAndUnmarshal(ctx, req, v) } func successOrNotModified(resp *internal.Response) bool { return internal.HasSuccessStatus(resp) || resp.Status == http.StatusNotModified } func successOrPreconditionFailed(resp *internal.Response) bool { return internal.HasSuccessStatus(resp) || resp.Status == http.StatusPreconditionFailed } golang-google-firebase-go-4.18.0/db/ref_test.go000066400000000000000000000467511505612111400212460ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package db import ( "context" "fmt" "net/http" "reflect" "testing" "firebase.google.com/go/v4/errorutils" ) type refOp func(r *Ref) error var testOps = []struct { name string resp interface{} op refOp }{ { "Get()", "test", func(r *Ref) error { var got string return r.Get(context.Background(), &got) }, }, { "GetWithETag()", "test", func(r *Ref) error { var got string _, err := r.GetWithETag(context.Background(), &got) return err }, }, { "GetShallow()", "test", func(r *Ref) error { var got string return r.GetShallow(context.Background(), &got) }, }, { "GetIfChanged()", "test", func(r *Ref) error { var got string _, _, err := r.GetIfChanged(context.Background(), "etag", &got) return err }, }, { "Set()", nil, func(r *Ref) error { return r.Set(context.Background(), "foo") }, }, { "SetIfUnchanged()", nil, func(r *Ref) error { _, err := r.SetIfUnchanged(context.Background(), "etag", "foo") return err }, }, { "Push()", map[string]interface{}{"name": "test"}, func(r *Ref) error { _, err := r.Push(context.Background(), "foo") return err }, }, { "Update()", nil, func(r *Ref) error { return r.Update(context.Background(), map[string]interface{}{"foo": "bar"}) }, }, { "Delete()", nil, func(r *Ref) error { return r.Delete(context.Background()) }, }, { "Transaction()", nil, func(r *Ref) error { fn := func(t TransactionNode) (interface{}, error) { var v interface{} if err := t.Unmarshal(&v); err != nil { return nil, err } return v, nil } return r.Transaction(context.Background(), fn) }, }, } func TestGet(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer 
srv.Close() cases := []interface{}{ nil, float64(1), true, "foo", map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, } var want []*testReq for _, tc := range cases { mock.Resp = tc var got interface{} if err := testref.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(tc, got) { t.Errorf("Get() = %v; want = %v", got, tc) } want = append(want, &testReq{Method: "GET", Path: "/peter.json"}) } checkAllRequests(t, mock.Reqs, want) } func TestInvalidGet(t *testing.T) { want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() got := func() {} if err := testref.Get(context.Background(), &got); err == nil { t.Errorf("Get(func) = nil; want error") } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } func TestGetWithStruct(t *testing.T) { want := person{Name: "Peter Parker", Age: 17} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() var got person if err := testref.Get(context.Background(), &got); err != nil { t.Fatal(err) } if want != got { t.Errorf("Get(struct) = %v; want = %v", got, want) } checkOnlyRequest(t, mock.Reqs, &testReq{Method: "GET", Path: "/peter.json"}) } func TestGetShallow(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() cases := []interface{}{ nil, float64(1), true, "foo", map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, map[string]interface{}{"name": "Peter Parker", "nestedChild": true}, } wantQuery := map[string]string{"shallow": "true"} var want []*testReq for _, tc := range cases { mock.Resp = tc var got interface{} if err := testref.GetShallow(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(tc, got) { t.Errorf("GetShallow() = %v; want = %v", got, tc) } want = append(want, &testReq{Method: "GET", Path: "/peter.json", Query: wantQuery}) } checkAllRequests(t, mock.Reqs, want) } 
func TestGetWithETag(t *testing.T) { want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{ Resp: want, Header: map[string]string{"ETag": "mock-etag"}, } srv := mock.Start(client) defer srv.Close() var got map[string]interface{} etag, err := testref.GetWithETag(context.Background(), &got) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("GetWithETag() = %v; want = %v", got, want) } if etag != "mock-etag" { t.Errorf("GetWithETag() = %q; want = %q", etag, "mock-etag") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }) } func TestGetIfChanged(t *testing.T) { want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{ Resp: want, Header: map[string]string{"ETag": "new-etag"}, } srv := mock.Start(client) defer srv.Close() var got map[string]interface{} ok, etag, err := testref.GetIfChanged(context.Background(), "old-etag", &got) if err != nil { t.Fatal(err) } if !ok { t.Errorf("GetIfChanged() = %v; want = %v", ok, true) } if !reflect.DeepEqual(want, got) { t.Errorf("GetIfChanged() = %v; want = %v", got, want) } if etag != "new-etag" { t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") } mock.Status = http.StatusNotModified mock.Resp = nil var got2 map[string]interface{} ok, etag, err = testref.GetIfChanged(context.Background(), "new-etag", &got2) if err != nil { t.Fatal(err) } if ok { t.Errorf("GetIfChanged() = %v; want = %v", ok, false) } if got2 != nil { t.Errorf("GetIfChanged() = %v; want nil", got2) } if etag != "new-etag" { t.Errorf("GetIfChanged() = %q; want = %q", etag, "new-etag") } checkAllRequests(t, mock.Reqs, []*testReq{ { Method: "GET", Path: "/peter.json", Header: http.Header{"If-None-Match": []string{"old-etag"}}, }, { Method: "GET", Path: "/peter.json", Header: http.Header{"If-None-Match": []string{"new-etag"}}, }, }) } func TestWelformedHttpError(t 
*testing.T) { mock := &mockServer{Resp: map[string]string{"error": "test error"}, Status: 500} srv := mock.Start(client) defer srv.Close() want := "http error status: 500; reason: test error" for _, tc := range testOps { t.Run(tc.name, func(t *testing.T) { err := tc.op(testref) if err == nil || err.Error() != want { t.Errorf("%s = %v; want = %v", tc.name, err, want) } if !errorutils.IsInternal(err) { t.Errorf("IsInternal(err) = false; want = true") } resp := errorutils.HTTPResponse(err) if resp == nil || resp.StatusCode != http.StatusInternalServerError { t.Errorf("HTTPResponse(err) = %v; want = {StatusCode: %d}", resp, http.StatusInternalServerError) } }) } wantReqs := len(testOps) if len(mock.Reqs) != wantReqs { t.Errorf("Requests = %d; want = %d", len(mock.Reqs), wantReqs) } } func TestUnexpectedHttpError(t *testing.T) { mock := &mockServer{Resp: "unexpected error", Status: 500} srv := mock.Start(client) defer srv.Close() want := "unexpected http response with status: 500\n\"unexpected error\"" for _, tc := range testOps { t.Run(tc.name, func(t *testing.T) { err := tc.op(testref) if err == nil || err.Error() != want { t.Errorf("%s = %v; want = %v", tc.name, err, want) } if !errorutils.IsInternal(err) { t.Errorf("IsInternal(err) = false; want = true") } resp := errorutils.HTTPResponse(err) if resp == nil || resp.StatusCode != http.StatusInternalServerError { t.Errorf("HTTPResponse(err) = %v; want = {StatusCode: %d}", resp, http.StatusInternalServerError) } }) } wantReqs := len(testOps) if len(mock.Reqs) != wantReqs { t.Errorf("Requests = %d; want = %d", len(mock.Reqs), wantReqs) } } func TestPlatformErrorCodes(t *testing.T) { mock := &mockServer{Resp: map[string]string{"error": "test error"}} srv := mock.Start(client) defer srv.Close() cases := []struct { name string status int check func(err error) bool }{ { name: "InvalidArgument", status: http.StatusBadRequest, check: errorutils.IsInvalidArgument, }, { name: "Unauthenticated", status: http.StatusUnauthorized, 
check: errorutils.IsUnauthenticated, }, { name: "NotFound", status: http.StatusNotFound, check: errorutils.IsNotFound, }, { name: "Internal", status: http.StatusInternalServerError, check: errorutils.IsInternal, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { mock.Status = tc.status want := fmt.Sprintf("http error status: %d; reason: test error", tc.status) err := testref.Delete(context.Background()) if err == nil || err.Error() != want { t.Errorf("Delete() = %v; want = %v", err, want) } if !tc.check(err) { t.Errorf("Is%s(err) = false; want = true", tc.name) } resp := errorutils.HTTPResponse(err) if resp == nil || resp.StatusCode != tc.status { t.Errorf("HTTPResponse(err) = %v; want = {StatusCode: %d}", resp, tc.status) } }) } } func TestInvalidPath(t *testing.T) { mock := &mockServer{Resp: "test"} srv := mock.Start(client) defer srv.Close() cases := []string{ "foo$", "foo.", "foo#", "foo]", "foo[", } for _, tc := range cases { r := client.NewRef(tc) for _, o := range testOps { err := o.op(r) if err == nil { t.Errorf("%s = nil; want = error", o.name) } } } if len(mock.Reqs) != 0 { t.Errorf("Requests = %v; want = empty", mock.Reqs) } } func TestInvalidChildPath(t *testing.T) { mock := &mockServer{Resp: "test"} srv := mock.Start(client) defer srv.Close() cases := []string{ "foo$", "foo.", "foo#", "foo]", "foo[", } for _, tc := range cases { r := testref.Child(tc) for _, o := range testOps { err := o.op(r) if err == nil { t.Errorf("%s = nil; want = error", o.name) } } } if len(mock.Reqs) != 0 { t.Errorf("Requests = %v; want = empty", mock.Reqs) } } func TestSet(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() cases := []interface{}{ 1, true, "foo", map[string]interface{}{"name": "Peter Parker", "age": float64(17)}, &person{"Peter Parker", 17}, } var want []*testReq for _, tc := range cases { if err := testref.Set(context.Background(), tc); err != nil { t.Fatal(err) } want = append(want, &testReq{ Method: "PUT", 
Path: "/peter.json", Body: serialize(tc), Query: map[string]string{"print": "silent"}, }) } checkAllRequests(t, mock.Reqs, want) } func TestInvalidSet(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() cases := []interface{}{ func() {}, make(chan int), } for _, tc := range cases { if err := testref.Set(context.Background(), tc); err == nil { t.Errorf("Set(%v) = nil; want = error", tc) } } if len(mock.Reqs) != 0 { t.Errorf("Set() = %v; want = empty", mock.Reqs) } } func TestSetIfUnchanged(t *testing.T) { mock := &mockServer{} srv := mock.Start(client) defer srv.Close() want := &person{"Peter Parker", 17} ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) if err != nil { t.Fatal(err) } if !ok { t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PUT", Path: "/peter.json", Body: serialize(want), Header: http.Header{"If-Match": []string{"mock-etag"}}, }) } func TestSetIfUnchangedError(t *testing.T) { mock := &mockServer{ Status: http.StatusPreconditionFailed, Resp: &person{"Tony Stark", 39}, } srv := mock.Start(client) defer srv.Close() want := &person{"Peter Parker", 17} ok, err := testref.SetIfUnchanged(context.Background(), "mock-etag", &want) if err != nil { t.Fatal(err) } if ok { t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PUT", Path: "/peter.json", Body: serialize(want), Header: http.Header{"If-Match": []string{"mock-etag"}}, }) } func TestPush(t *testing.T) { mock := &mockServer{Resp: map[string]string{"name": "new_key"}} srv := mock.Start(client) defer srv.Close() child, err := testref.Push(context.Background(), nil) if err != nil { t.Fatal(err) } if child.Key != "new_key" { t.Errorf("Push() = %q; want = %q", child.Key, "new_key") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "POST", Path: "/peter.json", Body: serialize(""), }) } func TestPushWithValue(t *testing.T) { mock := 
&mockServer{Resp: map[string]string{"name": "new_key"}} srv := mock.Start(client) defer srv.Close() want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} child, err := testref.Push(context.Background(), want) if err != nil { t.Fatal(err) } if child.Key != "new_key" { t.Errorf("Push() = %q; want = %q", child.Key, "new_key") } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "POST", Path: "/peter.json", Body: serialize(want), }) } func TestUpdate(t *testing.T) { want := map[string]interface{}{"name": "Peter Parker", "age": float64(17)} mock := &mockServer{Resp: want} srv := mock.Start(client) defer srv.Close() if err := testref.Update(context.Background(), want); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "PATCH", Path: "/peter.json", Body: serialize(want), Query: map[string]string{"print": "silent"}, }) } func TestInvalidUpdate(t *testing.T) { cases := []map[string]interface{}{ nil, make(map[string]interface{}), {"foo": func() {}}, } for _, tc := range cases { if err := testref.Update(context.Background(), tc); err == nil { t.Errorf("Update(%v) = nil; want error", tc) } } } func TestTransaction(t *testing.T) { mock := &mockServer{ Resp: &person{"Peter Parker", 17}, Header: map[string]string{"ETag": "mock-etag"}, } srv := mock.Start(client) defer srv.Close() var fn UpdateFn = func(t TransactionNode) (interface{}, error) { var p person if err := t.Unmarshal(&p); err != nil { return nil, err } p.Age++ return &p, nil } if err := testref.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } checkAllRequests(t, mock.Reqs, []*testReq{ { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ "name": "Peter Parker", "age": 18, }), Header: http.Header{"If-Match": []string{"mock-etag"}}, }, }) } func TestTransactionRetry(t *testing.T) { mock := &mockServer{ Resp: &person{"Peter Parker", 17}, 
Header: map[string]string{"ETag": "mock-etag1"}, } srv := mock.Start(client) defer srv.Close() cnt := 0 var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusPreconditionFailed mock.Header = map[string]string{"ETag": "mock-etag2"} mock.Resp = &person{"Peter Parker", 19} } else if cnt == 1 { mock.Status = http.StatusOK } cnt++ var p person if err := t.Unmarshal(&p); err != nil { return nil, err } p.Age++ return &p, nil } if err := testref.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } if cnt != 2 { t.Errorf("Transaction() retries = %d; want = %d", cnt, 2) } checkAllRequests(t, mock.Reqs, []*testReq{ { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ "name": "Peter Parker", "age": 18, }), Header: http.Header{"If-Match": []string{"mock-etag1"}}, }, { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ "name": "Peter Parker", "age": 20, }), Header: http.Header{"If-Match": []string{"mock-etag2"}}, }, }) } func TestTransactionError(t *testing.T) { mock := &mockServer{ Resp: &person{"Peter Parker", 17}, Header: map[string]string{"ETag": "mock-etag1"}, } srv := mock.Start(client) defer srv.Close() cnt := 0 want := "user error" var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusPreconditionFailed mock.Header = map[string]string{"ETag": "mock-etag2"} mock.Resp = &person{"Peter Parker", 19} } else if cnt == 1 { return nil, fmt.Errorf("%s", want) } cnt++ var p person if err := t.Unmarshal(&p); err != nil { return nil, err } p.Age++ return &p, nil } if err := testref.Transaction(context.Background(), fn); err == nil || err.Error() != want { t.Errorf("Transaction() = %v; want = %q", err, want) } if cnt != 1 { t.Errorf("Transaction() retries = %d; want = %d", cnt, 1) } checkAllRequests(t, mock.Reqs, []*testReq{ 
{ Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, { Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ "name": "Peter Parker", "age": 18, }), Header: http.Header{"If-Match": []string{"mock-etag1"}}, }, }) } func TestTransactionAbort(t *testing.T) { mock := &mockServer{ Resp: &person{"Peter Parker", 17}, Header: map[string]string{"ETag": "mock-etag1"}, } srv := mock.Start(client) defer srv.Close() cnt := 0 var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusPreconditionFailed mock.Header = map[string]string{"ETag": "mock-etag1"} } cnt++ var p person if err := t.Unmarshal(&p); err != nil { return nil, err } p.Age++ return &p, nil } err := testref.Transaction(context.Background(), fn) if err == nil { t.Errorf("Transaction() = nil; want error") } wanted := []*testReq{ { Method: "GET", Path: "/peter.json", Header: http.Header{"X-Firebase-ETag": []string{"true"}}, }, } for i := 0; i < txnRetries; i++ { wanted = append(wanted, &testReq{ Method: "PUT", Path: "/peter.json", Body: serialize(map[string]interface{}{ "name": "Peter Parker", "age": 18, }), Header: http.Header{"If-Match": []string{"mock-etag1"}}, }) } checkAllRequests(t, mock.Reqs, wanted) } func TestTransactionFailure(t *testing.T) { mock := &mockServer{ Resp: &person{"Peter Parker", 17}, Header: map[string]string{"ETag": "mock-etag1"}, } srv := mock.Start(client) defer srv.Close() cnt := 0 var fn UpdateFn = func(t TransactionNode) (interface{}, error) { if cnt == 0 { mock.Status = http.StatusInternalServerError mock.Resp = map[string]string{"error": "test error"} } cnt++ var p person if err := t.Unmarshal(&p); err != nil { return nil, err } p.Age++ return &p, nil } want := "http error status: 500; reason: test error" err := testref.Transaction(context.Background(), fn) if err == nil || err.Error() != want { t.Errorf("Transaction() = %v; want = %v", err, want) } if 
!errorutils.IsInternal(err) { t.Errorf("IsInternal() = false; want = true") } } func TestDelete(t *testing.T) { mock := &mockServer{Resp: "null"} srv := mock.Start(client) defer srv.Close() if err := testref.Delete(context.Background()); err != nil { t.Fatal(err) } checkOnlyRequest(t, mock.Reqs, &testReq{ Method: "DELETE", Path: "/peter.json", }) } golang-google-firebase-go-4.18.0/errorutils/000077500000000000000000000000001505612111400207145ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/errorutils/errorutils.go000066400000000000000000000126171505612111400234640ustar00rootroot00000000000000// Copyright 2020 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package errorutils provides functions for checking and handling error conditions. package errorutils import ( "net/http" "firebase.google.com/go/v4/internal" ) // IsInvalidArgument checks if the given error was due to an invalid client argument. func IsInvalidArgument(err error) bool { return internal.HasPlatformErrorCode(err, internal.InvalidArgument) } // IsFailedPrecondition checks if the given error was because a request could not be executed // in the current system state, such as deleting a non-empty directory. func IsFailedPrecondition(err error) bool { return internal.HasPlatformErrorCode(err, internal.FailedPrecondition) } // IsOutOfRange checks if the given error due to an invalid range specified by the client. 
func IsOutOfRange(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.OutOfRange)
}

// IsUnauthenticated checks if the given error was caused by an unauthenticated request.
//
// Unauthenticated requests are due to a missing, invalid, or expired OAuth token.
func IsUnauthenticated(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Unauthenticated)
}

// IsPermissionDenied checks if the given error was due to a client not having sufficient
// permissions.
//
// This can happen because the OAuth token does not have the right scopes, the client doesn't have
// permission, or the API has not been enabled for the client project.
func IsPermissionDenied(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.PermissionDenied)
}

// IsNotFound checks if the given error was due to a specified resource not being found.
//
// This may also occur when the request is rejected for undisclosed reasons, such as whitelisting.
func IsNotFound(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.NotFound)
}

// IsConflict checks if the given error was due to a concurrency conflict, such as a
// read-modify-write conflict.
//
// This represents an HTTP 409 Conflict status code, without additional information to distinguish
// between ABORTED or ALREADY_EXISTS error conditions.
func IsConflict(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Conflict)
}

// IsAborted checks if the given error was due to a concurrency conflict, such as a
// read-modify-write conflict.
func IsAborted(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Aborted)
}

// IsAlreadyExists checks if the given error was because a resource that a client tried to create
// already exists.
func IsAlreadyExists(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.AlreadyExists)
}

// IsResourceExhausted checks if the given error was caused by either running out of a quota or
// reaching a rate limit.
func IsResourceExhausted(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.ResourceExhausted)
}

// IsCancelled checks if the given error was due to the client cancelling a request.
func IsCancelled(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Cancelled)
}

// IsDataLoss checks if the given error was due to an unrecoverable data loss or corruption.
//
// The client should report such errors to the end user.
func IsDataLoss(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.DataLoss)
}

// IsUnknown checks if the given error was caused by an unknown server error.
//
// This typically indicates a server bug.
func IsUnknown(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Unknown)
}

// IsInternal checks if the given error was due to an internal server error.
//
// This typically indicates a server bug.
func IsInternal(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Internal)
}

// IsUnavailable checks if the given error was caused by an unavailable service.
//
// This typically indicates that the target service is temporarily down.
func IsUnavailable(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.Unavailable)
}

// IsDeadlineExceeded checks if the given error was due to a request exceeding a deadline.
//
// This will happen only if the caller sets a deadline that is shorter than the method's default
// deadline (i.e. requested deadline is not enough for the server to process the request) and the
// request did not finish within the deadline.
func IsDeadlineExceeded(err error) bool {
	return internal.HasPlatformErrorCode(err, internal.DeadlineExceeded)
}

// HTTPResponse returns the http.Response instance that caused the given error.
//
// If the error was not caused by an HTTP error response, returns nil.
//
// Returns a buffered copy of the original response received from the network stack.
It is safe to // read the response content from the returned http.Response. func HTTPResponse(err error) *http.Response { fe, ok := err.(*internal.FirebaseError) if ok { return fe.Response } return nil } golang-google-firebase-go-4.18.0/firebase.go000066400000000000000000000165121505612111400206160ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package firebase is the entry point to the Firebase Admin SDK. It provides functionality for initializing App // instances, which serve as the central entities that provide access to various other Firebase services exposed // from the SDK. package firebase import ( "context" "encoding/json" "errors" "io/ioutil" "os" "cloud.google.com/go/firestore" "firebase.google.com/go/v4/appcheck" "firebase.google.com/go/v4/auth" "firebase.google.com/go/v4/db" "firebase.google.com/go/v4/iid" "firebase.google.com/go/v4/internal" "firebase.google.com/go/v4/messaging" "firebase.google.com/go/v4/remoteconfig" "firebase.google.com/go/v4/storage" "google.golang.org/api/option" "google.golang.org/api/transport" ) var defaultAuthOverrides = make(map[string]interface{}) // Version of the Firebase Go Admin SDK. const Version = "4.18.0" // firebaseEnvName is the name of the environment variable with the Config. const firebaseEnvName = "FIREBASE_CONFIG" // An App holds configuration and state common to all Firebase services that are exposed from the SDK. 
type App struct { authOverride map[string]interface{} dbURL string projectID string serviceAccountID string storageBucket string opts []option.ClientOption } // Config represents the configuration used to initialize an App. type Config struct { AuthOverride *map[string]interface{} `json:"databaseAuthVariableOverride"` DatabaseURL string `json:"databaseURL"` ProjectID string `json:"projectId"` ServiceAccountID string `json:"serviceAccountId"` StorageBucket string `json:"storageBucket"` } // Auth returns an instance of auth.Client. func (a *App) Auth(ctx context.Context) (*auth.Client, error) { conf := &internal.AuthConfig{ ProjectID: a.projectID, Opts: a.opts, ServiceAccountID: a.serviceAccountID, Version: Version, } return auth.NewClient(ctx, conf) } // Database returns an instance of db.Client to interact with the default Firebase Database // configured via Config.DatabaseURL. func (a *App) Database(ctx context.Context) (*db.Client, error) { return a.DatabaseWithURL(ctx, a.dbURL) } // DatabaseWithURL returns an instance of db.Client to interact with the Firebase Database // identified by the given URL. func (a *App) DatabaseWithURL(ctx context.Context, url string) (*db.Client, error) { conf := &internal.DatabaseConfig{ AuthOverride: a.authOverride, URL: url, Opts: a.opts, Version: Version, } return db.NewClient(ctx, conf) } // Storage returns a new instance of storage.Client. func (a *App) Storage(ctx context.Context) (*storage.Client, error) { conf := &internal.StorageConfig{ Opts: a.opts, Bucket: a.storageBucket, } return storage.NewClient(ctx, conf) } // Firestore returns a new firestore.Client instance from the https://godoc.org/cloud.google.com/go/firestore // package. func (a *App) Firestore(ctx context.Context) (*firestore.Client, error) { if a.projectID == "" { return nil, errors.New("project id is required to access Firestore") } return firestore.NewClient(ctx, a.projectID, a.opts...) } // InstanceID returns an instance of iid.Client. 
func (a *App) InstanceID(ctx context.Context) (*iid.Client, error) { conf := &internal.InstanceIDConfig{ ProjectID: a.projectID, Opts: a.opts, Version: Version, } return iid.NewClient(ctx, conf) } // Messaging returns an instance of messaging.Client. func (a *App) Messaging(ctx context.Context) (*messaging.Client, error) { conf := &internal.MessagingConfig{ ProjectID: a.projectID, Opts: a.opts, Version: Version, } return messaging.NewClient(ctx, conf) } // AppCheck returns an instance of appcheck.Client. func (a *App) AppCheck(ctx context.Context) (*appcheck.Client, error) { conf := &internal.AppCheckConfig{ ProjectID: a.projectID, } return appcheck.NewClient(ctx, conf) } // RemoteConfig returns an instance of remoteconfig.Client. func (a *App) RemoteConfig(ctx context.Context) (*remoteconfig.Client, error) { conf := &internal.RemoteConfigClientConfig{ ProjectID: a.projectID, Opts: a.opts, Version: Version, } return remoteconfig.NewClient(ctx, conf) } // NewApp creates a new App from the provided config and client options. // // If the client options contain a valid credential (a service account file, a refresh token // file or an oauth2.TokenSource) the App will be authenticated using that credential. Otherwise, // NewApp attempts to authenticate the App with Google application default credentials. // If `config` is nil, the SDK will attempt to load the config options from the // `FIREBASE_CONFIG` environment variable. If the value in it starts with a `{` it is parsed as a // JSON object, otherwise it is assumed to be the name of the JSON file containing the options. func NewApp(ctx context.Context, config *Config, opts ...option.ClientOption) (*App, error) { o := []option.ClientOption{option.WithScopes(internal.FirebaseScopes...)} o = append(o, opts...) if config == nil { var err error if config, err = getConfigDefaults(); err != nil { return nil, err } } pid := getProjectID(ctx, config, o...) 
ao := defaultAuthOverrides if config.AuthOverride != nil { ao = *config.AuthOverride } return &App{ authOverride: ao, dbURL: config.DatabaseURL, projectID: pid, serviceAccountID: config.ServiceAccountID, storageBucket: config.StorageBucket, opts: o, }, nil } // getConfigDefaults reads the default config file, defined by the FIREBASE_CONFIG // env variable, used only when options are nil. func getConfigDefaults() (*Config, error) { fbc := &Config{} confFileName := os.Getenv(firebaseEnvName) if confFileName == "" { return fbc, nil } var dat []byte if confFileName[0] == byte('{') { dat = []byte(confFileName) } else { var err error if dat, err = ioutil.ReadFile(confFileName); err != nil { return nil, err } } if err := json.Unmarshal(dat, fbc); err != nil { return nil, err } // Some special handling necessary for db auth overrides var m map[string]interface{} if err := json.Unmarshal(dat, &m); err != nil { return nil, err } if ao, ok := m["databaseAuthVariableOverride"]; ok && ao == nil { // Auth overrides are explicitly set to null var nullMap map[string]interface{} fbc.AuthOverride = &nullMap } return fbc, nil } func getProjectID(ctx context.Context, config *Config, opts ...option.ClientOption) string { if config.ProjectID != "" { return config.ProjectID } creds, _ := transport.Creds(ctx, opts...) if creds != nil && creds.ProjectID != "" { return creds.ProjectID } if pid := os.Getenv("GOOGLE_CLOUD_PROJECT"); pid != "" { return pid } return os.Getenv("GCLOUD_PROJECT") } golang-google-firebase-go-4.18.0/firebase_test.go000066400000000000000000000445411505612111400216600ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package firebase import ( "context" "encoding/json" "fmt" "io/ioutil" "log" "net/http" "net/http/httptest" "os" "reflect" "strconv" "strings" "testing" "time" "firebase.google.com/go/v4/messaging" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/option" "google.golang.org/api/transport" ) const credEnvVar = "GOOGLE_APPLICATION_CREDENTIALS" func TestMain(m *testing.M) { // This isolates the tests from a possiblity that the default config env // variable is set to a valid file containing the wanted default config, // but we the test is not expecting it. configOld := overwriteEnv(firebaseEnvName, "") defer reinstateEnv(firebaseEnvName, configOld) os.Exit(m.Run()) } func TestServiceAcctFile(t *testing.T) { app, err := NewApp(context.Background(), nil, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if app.projectID != "mock-project-id" { t.Errorf("Project ID: %q; want: %q", app.projectID, "mock-project-id") } if len(app.opts) != 2 { t.Errorf("Client opts: %d; want: 2", len(app.opts)) } } func TestClientOptions(t *testing.T) { ts := initMockTokenServer() defer ts.Close() b, err := mockServiceAcct(ts.URL) if err != nil { t.Fatal(err) } config, err := google.JWTConfigFromJSON(b) if err != nil { t.Fatal(err) } ctx := context.Background() app, err := NewApp(ctx, nil, option.WithTokenSource(config.TokenSource(ctx))) if err != nil { t.Fatal(err) } var bearer string service := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bearer = r.Header.Get("Authorization") 
w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"output": "test"}`)) })) defer service.Close() client, _, err := transport.NewHTTPClient(ctx, app.opts...) if err != nil { t.Fatal(err) } resp, err := client.Get(service.URL) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Status: %d; want: %d", resp.StatusCode, http.StatusOK) } if bearer != "Bearer mock-token" { t.Errorf("Bearer token: %q; want: %q", bearer, "Bearer mock-token") } } func TestRefreshTokenFile(t *testing.T) { app, err := NewApp(context.Background(), nil, option.WithCredentialsFile("testdata/refresh_token.json")) if err != nil { t.Fatal(err) } if len(app.opts) != 2 { t.Errorf("Client opts: %d; want: 2", len(app.opts)) } } func TestRefreshTokenFileWithConfig(t *testing.T) { config := &Config{ProjectID: "mock-project-id"} app, err := NewApp(context.Background(), config, option.WithCredentialsFile("testdata/refresh_token.json")) if err != nil { t.Fatal(err) } if app.projectID != "mock-project-id" { t.Errorf("Project ID: %q; want: mock-project-id", app.projectID) } if len(app.opts) != 2 { t.Errorf("Client opts: %d; want: 2", len(app.opts)) } } func TestRefreshTokenWithEnvVar(t *testing.T) { verify := func(varName string) { current := os.Getenv(varName) if err := os.Setenv(varName, "mock-project-id"); err != nil { t.Fatal(err) } defer os.Setenv(varName, current) app, err := NewApp(context.Background(), nil, option.WithCredentialsFile("testdata/refresh_token.json")) if err != nil { t.Fatal(err) } if app.projectID != "mock-project-id" { t.Errorf("[env=%s] Project ID: %q; want: mock-project-id", varName, app.projectID) } } for _, varName := range []string{"GCLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"} { verify(varName) } } func TestAppDefault(t *testing.T) { current := os.Getenv(credEnvVar) if err := os.Setenv(credEnvVar, "testdata/service_account.json"); err != nil { t.Fatal(err) } defer os.Setenv(credEnvVar, current) app, err := NewApp(context.Background(), nil) 
if err != nil { t.Fatal(err) } if len(app.opts) != 1 { t.Errorf("Client opts: %d; want: 1", len(app.opts)) } } func TestAppDefaultWithInvalidFile(t *testing.T) { current := os.Getenv(credEnvVar) if err := os.Setenv(credEnvVar, "testdata/non_existing.json"); err != nil { t.Fatal(err) } defer os.Setenv(credEnvVar, current) app, err := NewApp(context.Background(), nil) if app == nil || err != nil { t.Fatalf("NewApp() = (%v, %v); want = (app, nil)", app, err) } } func TestInvalidCredentialFile(t *testing.T) { invalidFiles := []string{ "testdata", "testdata/plain_text.txt", } ctx := context.Background() for _, tc := range invalidFiles { app, err := NewApp(ctx, nil, option.WithCredentialsFile(tc)) if app == nil || err != nil { t.Fatalf("NewApp() = (%v, %v); want = (app, nil)", app, err) } } } func TestExplicitNoAuth(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithoutAuthentication()) if app == nil || err != nil { t.Fatalf("NewApp() = (%v, %v); want = (app, nil)", app, err) } } func TestAuth(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if c, err := app.Auth(ctx); c == nil || err != nil { t.Errorf("Auth() = (%v, %v); want (auth, nil)", c, err) } } func TestDatabase(t *testing.T) { ctx := context.Background() conf := &Config{DatabaseURL: "https://mock-db.firebaseio.com"} app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if app.authOverride == nil || len(app.authOverride) != 0 { t.Errorf("AuthOverrides = %v; want = empty map", app.authOverride) } if c, err := app.Database(ctx); c == nil || err != nil { t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) } url := "https://other-mock-db.firebaseio.com" if c, err := app.DatabaseWithURL(ctx, url); c == nil || err != nil { t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) } } func 
TestDatabaseAuthOverrides(t *testing.T) { cases := []map[string]interface{}{ nil, {}, {"uid": "user1"}, } for _, tc := range cases { ctx := context.Background() conf := &Config{ AuthOverride: &tc, DatabaseURL: "https://mock-db.firebaseio.com", } app, err := NewApp(ctx, conf, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(app.authOverride, tc) { t.Errorf("AuthOverrides = %v; want = %v", app.authOverride, tc) } if c, err := app.Database(ctx); c == nil || err != nil { t.Errorf("Database() = (%v, %v); want (db, nil)", c, err) } } } func TestStorage(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if c, err := app.Storage(ctx); c == nil || err != nil { t.Errorf("Storage() = (%v, %v); want (auth, nil)", c, err) } } func TestFirestore(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if c, err := app.Firestore(ctx); c == nil || err != nil { t.Errorf("Firestore() = (%v, %v); want (auth, nil)", c, err) } } func TestFirestoreWithProjectID(t *testing.T) { verify := func(varName string) { current := os.Getenv(varName) if err := os.Setenv(varName, ""); err != nil { t.Fatal(err) } defer os.Setenv(varName, current) ctx := context.Background() config := &Config{ProjectID: "project-id"} app, err := NewApp(ctx, config, option.WithCredentialsFile("testdata/refresh_token.json")) if err != nil { t.Fatal(err) } if c, err := app.Firestore(ctx); c == nil || err != nil { t.Errorf("[env=%s] Firestore() = (%v, %v); want (auth, nil)", varName, c, err) } } for _, varName := range []string{"GCLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"} { verify(varName) } } func TestFirestoreWithNoProjectID(t *testing.T) { unsetVariable := func(varName string) string { current := os.Getenv(varName) if err := 
os.Setenv(varName, ""); err != nil { t.Fatal(err) } return current } for _, varName := range []string{"GCLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"} { if current := unsetVariable(varName); current != "" { defer os.Setenv(varName, current) } } ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/refresh_token.json")) if err != nil { t.Fatal(err) } if c, err := app.Firestore(ctx); c != nil || err == nil { t.Errorf("Firestore() = (%v, %v); want (nil, error)", c, err) } } func TestInstanceID(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if c, err := app.InstanceID(ctx); c == nil || err != nil { t.Errorf("InstanceID() = (%v, %v); want (iid, nil)", c, err) } } func TestMessaging(t *testing.T) { ctx := context.Background() app, err := NewApp(ctx, nil, option.WithCredentialsFile("testdata/service_account.json")) if err != nil { t.Fatal(err) } if c, err := app.Messaging(ctx); c == nil || err != nil { t.Errorf("Messaging() = (%v, %v); want (iid, nil)", c, err) } } func TestMessagingSendWithCustomEndpoint(t *testing.T) { name := "custom-endpoint-ok" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte("{ \"name\":\"" + name + "\" }")) })) defer ts.Close() ctx := context.Background() tokenSource := &testTokenSource{AccessToken: "mock-token-from-custom"} app, err := NewApp( ctx, &Config{ProjectID: "test-project-id"}, option.WithTokenSource(tokenSource), option.WithEndpoint(ts.URL), ) if err != nil { t.Fatal(err) } c, err := app.Messaging(ctx) if c == nil || err != nil { t.Fatalf("Messaging() = (%v, %v); want (iid, nil)", c, err) } msg := &messaging.Message{ Token: "token", } n, err := c.Send(ctx, msg) if n != name || err != nil { t.Errorf("Send() = (%q, %v); want (%q, nil)", n, err, name) } } func TestCustomTokenSource(t 
*testing.T) { ctx := context.Background() ts := &testTokenSource{AccessToken: "mock-token-from-custom"} app, err := NewApp(ctx, nil, option.WithTokenSource(ts)) if err != nil { t.Fatal(err) } client, _, err := transport.NewHTTPClient(ctx, app.opts...) if err != nil { t.Fatal(err) } var bearer string service := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { bearer = r.Header.Get("Authorization") w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"output": "test"}`)) })) defer service.Close() resp, err := client.Get(service.URL) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Errorf("Status: %d; want: %d", resp.StatusCode, http.StatusOK) } if bearer != "Bearer "+ts.AccessToken { t.Errorf("Bearer token: %q; want: %q", bearer, "Bearer "+ts.AccessToken) } } func TestVersion(t *testing.T) { segments := strings.Split(Version, ".") if len(segments) != 3 { t.Errorf("Incorrect number of segments: %d; want: 3", len(segments)) } for _, segment := range segments { if _, err := strconv.Atoi(segment); err != nil { t.Errorf("Invalid segment in version number: %q; want integer", segment) } } } func TestAutoInit(t *testing.T) { var nullMap map[string]interface{} uidMap := map[string]interface{}{"uid": "test"} tests := []struct { name string optionsConfig string initOptions *Config wantOptions *Config }{ { "", "", nil, &Config{ProjectID: "mock-project-id"}, // from default creds here and below. 
}, { "", "testdata/firebase_config.json", nil, &Config{ DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, }, { "", `{ "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", StorageBucket: "auto-init.storage.bucket", }, }, { "", "testdata/firebase_config_partial.json", nil, &Config{ProjectID: "auto-init-project-id"}, }, { "", `{"projectId": "auto-init-project-id"}`, nil, &Config{ProjectID: "auto-init-project-id"}, }, { "", "testdata/firebase_config_partial.json", &Config{StorageBucket: "sb1-mock"}, &Config{ ProjectID: "mock-project-id", StorageBucket: "sb1-mock", }, }, { "", `{"projectId": "auto-init-project-id"}`, &Config{StorageBucket: "sb1-mock"}, &Config{ ProjectID: "mock-project-id", // from default creds StorageBucket: "sb1-mock", }, }, { "", "testdata/firebase_config_partial.json", &Config{}, &Config{ProjectID: "mock-project-id"}, }, { "", `{"projectId": "auto-init-project-id"}`, &Config{}, &Config{ProjectID: "mock-project-id"}, }, { "", "testdata/firebase_config_invalid_key.json", nil, &Config{ ProjectID: "mock-project-id", // from default creds StorageBucket: "auto-init.storage.bucket", }, }, { "", `{ "obviously_bad_key": "mock-project-id", "storageBucket": "auto-init.storage.bucket" }`, nil, &Config{ ProjectID: "mock-project-id", StorageBucket: "auto-init.storage.bucket", }, }, { "", `{ "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "databaseAuthVariableOverride": null }`, nil, &Config{ DatabaseURL: "https://auto-init.database.url", ProjectID: "auto-init-project-id", AuthOverride: &nullMap, }, }, { "", `{ "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "databaseAuthVariableOverride": {"uid": "test"} }`, nil, &Config{ DatabaseURL: 
"https://auto-init.database.url", ProjectID: "auto-init-project-id", AuthOverride: &uidMap, }, }, } credOld := overwriteEnv(credEnvVar, "testdata/service_account.json") defer reinstateEnv(credEnvVar, credOld) for _, test := range tests { t.Run(fmt.Sprintf("NewApp(%s)", test.name), func(t *testing.T) { overwriteEnv(firebaseEnvName, test.optionsConfig) app, err := NewApp(context.Background(), test.initOptions) if err != nil { t.Error(err) } else { compareConfig(app, test.wantOptions, t) } }) } } func TestAutoInitInvalidFiles(t *testing.T) { tests := []struct { name string filename string wantError string }{ { "NonexistingFile", "testdata/no_such_file.json", "open testdata/no_such_file.json: no such file or directory", }, { "InvalidJSON", "testdata/firebase_config_invalid.json", "invalid character 'b' looking for beginning of value", }, { "EmptyFile", "testdata/firebase_config_empty.json", "unexpected end of JSON input", }, } credOld := overwriteEnv(credEnvVar, "testdata/service_account.json") defer reinstateEnv(credEnvVar, credOld) for _, test := range tests { t.Run(test.name, func(t *testing.T) { overwriteEnv(firebaseEnvName, test.filename) _, err := NewApp(context.Background(), nil) if err == nil || err.Error() != test.wantError { t.Errorf("%s got error = %s; want = %s", test.name, err, test.wantError) } }) } } type testTokenSource struct { AccessToken string Expiry time.Time } func (t *testTokenSource) Token() (*oauth2.Token, error) { return &oauth2.Token{ AccessToken: t.AccessToken, Expiry: t.Expiry, }, nil } func compareConfig(got *App, want *Config, t *testing.T) { if got.dbURL != want.DatabaseURL { t.Errorf("app.dbURL = %q; want = %q", got.dbURL, want.DatabaseURL) } if want.AuthOverride != nil { if !reflect.DeepEqual(got.authOverride, *want.AuthOverride) { t.Errorf("app.ao = %#v; want = %#v", got.authOverride, *want.AuthOverride) } } else if !reflect.DeepEqual(got.authOverride, defaultAuthOverrides) { t.Errorf("app.ao = %#v; want = nil", got.authOverride) } if 
got.projectID != want.ProjectID {
		t.Errorf("app.projectID = %q; want = %q", got.projectID, want.ProjectID)
	}
	if got.storageBucket != want.StorageBucket {
		t.Errorf("app.storageBucket = %q; want = %q", got.storageBucket, want.StorageBucket)
	}
}

// mockServiceAcct generates a service account configuration with the provided URL as the
// token_uri value.
//
// It reads testdata/service_account.json, replaces its "token_uri" entry with tokenURL and
// returns the re-encoded JSON. Read or decode failures are returned as the error.
func mockServiceAcct(tokenURL string) ([]byte, error) {
	b, err := ioutil.ReadFile("testdata/service_account.json")
	if err != nil {
		return nil, err
	}

	var parsed map[string]interface{}
	if err := json.Unmarshal(b, &parsed); err != nil {
		return nil, err
	}
	parsed["token_uri"] = tokenURL
	return json.Marshal(parsed)
}

// initMockTokenServer starts a mock HTTP server that Apps can invoke during tests to obtain
// OAuth2 access tokens.
//
// Every request is answered with the same JSON body carrying the access token "mock-token".
// The caller is responsible for closing the returned server.
func initMockTokenServer() *httptest.Server {
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{ "access_token": "mock-token", "scope": "user", "token_type": "bearer", "expires_in": 3600 }`))
	}))
}

// overwriteEnv overwrites an environment variable; it is used in testing.
//
// An empty newVal unsets varName; any other value is set as-is. The value the variable had
// before the call (as reported by os.Getenv) is returned so that it can be restored later.
// A failure to set or unset the variable terminates the process via log.Fatal.
func overwriteEnv(varName, newVal string) string {
	oldVal := os.Getenv(varName)
	if newVal == "" {
		if err := os.Unsetenv(varName); err != nil {
			log.Fatal(err)
		}
	} else if err := os.Setenv(varName, newVal); err != nil {
		log.Fatal(err)
	}
	return oldVal
}

// reinstateEnv restores the environment variable, will usually be used deferred with overwriteEnv.
func reinstateEnv(varName, oldVal string) { if len(varName) > 0 { os.Setenv(varName, oldVal) } else { os.Unsetenv(varName) } } golang-google-firebase-go-4.18.0/go.mod000066400000000000000000000057331505612111400176200ustar00rootroot00000000000000module firebase.google.com/go/v4 go 1.23.0 require ( cloud.google.com/go/firestore v1.18.0 cloud.google.com/go/storage v1.53.0 github.com/MicahParks/keyfunc v1.9.0 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/go-cmp v0.7.0 golang.org/x/oauth2 v0.30.0 google.golang.org/api v0.231.0 google.golang.org/appengine/v2 v2.0.6 ) require ( cel.dev/expr v0.23.1 // indirect cloud.google.com/go v0.121.0 // indirect cloud.google.com/go/auth v0.16.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect cloud.google.com/go/iam v1.5.2 // indirect cloud.google.com/go/longrunning v0.6.7 // indirect cloud.google.com/go/monitoring v1.24.2 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 // indirect github.com/envoyproxy/go-control-plane/envoy v1.32.4 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-jose/go-jose/v4 v4.0.5 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/planetscale/vtprotobuf 
v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect github.com/zeebo/errs v1.4.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.35.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect golang.org/x/crypto v0.40.0 // indirect golang.org/x/net v0.42.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.34.0 // indirect golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.11.0 // indirect google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250505200425-f936aa4a68b2 // indirect google.golang.org/grpc v1.72.0 // indirect google.golang.org/protobuf v1.36.6 // indirect ) golang-google-firebase-go-4.18.0/go.sum000066400000000000000000000363051505612111400176440ustar00rootroot00000000000000cel.dev/expr v0.23.1 h1:K4KOtPCJQjVggkARsjG9RWXP6O4R73aHeJMa/dmCQQg= cel.dev/expr v0.23.1/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.121.0 h1:pgfwva8nGw7vivjZiRfrmglGWiCJBP+0OmDpenG/Fwg= cloud.google.com/go v0.121.0/go.mod h1:rS7Kytwheu/y9buoDmu5EIpMMCI4Mb8ND4aeN4Vwj7Q= cloud.google.com/go/auth v0.16.1 h1:XrXauHMd30LhQYVRHLGvJiYeczweKQXZxsTbV9TiguU= cloud.google.com/go/auth v0.16.1/go.mod h1:1howDHJ5IETh/LwYs3ZxvlkXF48aSqqJUM+5o02dNOI= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod 
h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/firestore v1.18.0 h1:cuydCaLS7Vl2SatAeivXyhbhDEIR8BDmtn4egDhIn2s= cloud.google.com/go/firestore v1.18.0/go.mod h1:5ye0v48PhseZBdcl0qbl3uttu7FIEwEYVaWm0UIEOEU= cloud.google.com/go/iam v1.5.2 h1:qgFRAGEmd8z6dJ/qyEchAuL9jpswyODjA2lS+w234g8= cloud.google.com/go/iam v1.5.2/go.mod h1:SE1vg0N81zQqLzQEwxL2WI6yhetBdbNQuTvIKCSkUHE= cloud.google.com/go/logging v1.13.0 h1:7j0HgAp0B94o1YRDqiqm26w4q1rDMH7XNRU34lJXHYc= cloud.google.com/go/logging v1.13.0/go.mod h1:36CoKh6KA/M0PbhPKMq6/qety2DCAErbhXT62TuXALA= cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE= cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY= cloud.google.com/go/monitoring v1.24.2 h1:5OTsoJ1dXYIiMiuL+sYscLc9BumrL3CarVLL7dd7lHM= cloud.google.com/go/monitoring v1.24.2/go.mod h1:x7yzPWcgDRnPEv3sI+jJGBkwl5qINf+6qY4eq0I9B4U= cloud.google.com/go/storage v1.53.0 h1:gg0ERZwL17pJ+Cz3cD2qS60w1WMDnwcm5YPAIQBHUAw= cloud.google.com/go/storage v1.53.0/go.mod h1:7/eO2a/srr9ImZW9k5uufcNahT2+fPb8w5it1i5boaA= cloud.google.com/go/trace v1.11.6 h1:2O2zjPzqPYAHrn3OKl029qlqG6W8ZdYaOWRyr8NgMT4= cloud.google.com/go/trace v1.11.6/go.mod h1:GA855OeDEBiBMzcckLPE2kDunIpC72N+Pq8WFieFjnI= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 h1:ErKg/3iS1AKcTkf3yixlZ54f9U1rljCkQyEXWUnIUxc= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0/go.mod h1:yAZHSGnqScoU556rBOVkwLze6WP5N+U11RHuWaGVxwY= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 h1:fYE9p3esPxA/C0rQ0AHhP0drtPXDRhaWiwg1DPqO7IU= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0/go.mod h1:BnBReJLvVYx2CS/UHOgVz2BXKXD9wsQPxZug20nZhd0= 
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.51.0 h1:OqVGm6Ei3x5+yZmSJG1Mh2NwHvpVmZ08CB5qJhT9Nuk= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.51.0/go.mod h1:SZiPHWGOOk3bl8tkevxkoiwPgsIl6CwrWcbwjfHZpdM= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 h1:6/0iUd0xrnX7qt+mLNRwg5c0PGv8wpE8K90ryANQwMI= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0/go.mod h1:otE2jQekW/PqXk1Awf5lmfokJx4uwuqcj1ab5SpGeW0= github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID3+o= github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls= github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= github.com/envoyproxy/go-control-plane v0.13.4/go.mod h1:kDfuBlDVsSj2MjrLEtRWtHlsWIFcGyB2RMO44Dc5GZA= github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= github.com/envoyproxy/protoc-gen-validate v1.2.1 
h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE= github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod 
h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/spiffe/go-spiffe/v2 v2.5.0 h1:N2I01KCUkv1FAjZXJMwh95KK1ZIQLYbPfhaxw8WS0hE= github.com/spiffe/go-spiffe/v2 v2.5.0/go.mod h1:P+NxobPc6wXhVtINNtFjNWGBTreew1GBUCwT2wPmb7g= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/detectors/gcp v1.35.0 h1:bGvFt68+KTiAKFlacHW6AhA56GF2rS0bdD3aJYEnmzA= go.opentelemetry.io/contrib/detectors/gcp v1.35.0/go.mod 
h1:qGWP8/+ILwMRIUf9uIVLloR1uo5ZYAslM4O6OqUi1DA= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 h1:x7wzEgXfnzJcHDwStJT+mxOz4etr2EcexjqhBvmoakw= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0/go.mod h1:rg+RlpR5dKwaS95IyyZqj5Wd4E13lk/msnTS0Xl9lJM= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ= go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.35.0 h1:PB3Zrjs1sG1GBX51SXyTSoOTqcDglmsk7nT6tkKPb/k= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.35.0/go.mod h1:U2R3XyVPzn0WX7wOIypPuptulsMcPDPs/oiSVOMVnHY= go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= 
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.231.0 h1:LbUD5FUl0C4qwia2bjXhCMH65yz1MLPzA/0OYEsYY7Q= google.golang.org/api v0.231.0/go.mod h1:H52180fPI/QQlUc0F4xWfGZILdv09GCWKt2bcsn164A= google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2 h1:1tXaIXCracvtsRxSBsYDiSBN0cuJvM7QYW+MrpIRY78= google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2/go.mod h1:49MsLSx0oWMOZqcpB3uL8ZOkAh1+TndpJ8ONoCBWiZk= google.golang.org/genproto/googleapis/api v0.0.0-20250505200425-f936aa4a68b2 h1:vPV0tzlsK6EzEDHNNH5sa7Hs9bd7iXR7B1tSiPepkV0= google.golang.org/genproto/googleapis/api 
v0.0.0-20250505200425-f936aa4a68b2/go.mod h1:pKLAc5OolXC3ViWGI62vvC0n10CpwAtRcTNCFwTKBEw= google.golang.org/genproto/googleapis/rpc v0.0.0-20250505200425-f936aa4a68b2 h1:IqsN8hx+lWLqlN+Sc3DoMy/watjofWiU8sRFgQ8fhKM= google.golang.org/genproto/googleapis/rpc v0.0.0-20250505200425-f936aa4a68b2/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= golang-google-firebase-go-4.18.0/iid/000077500000000000000000000000001505612111400172475ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/iid/iid.go000066400000000000000000000130531505612111400203450ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package iid contains functions for deleting instance IDs from Firebase projects. 
package iid import ( "context" "errors" "fmt" "net/http" "strings" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/internal" ) const iidEndpoint = "https://console.firebase.google.com/v1" var errorMessages = map[int]string{ http.StatusBadRequest: "malformed instance id argument", http.StatusUnauthorized: "request not authorized", http.StatusForbidden: "project does not match instance ID or the client does not have sufficient privileges", http.StatusNotFound: "failed to find the instance id", http.StatusConflict: "already deleted", http.StatusTooManyRequests: "request throttled out by the backend server", http.StatusInternalServerError: "internal server error", http.StatusServiceUnavailable: "backend servers are over capacity", } // IsInvalidArgument checks if the given error was due to an invalid instance ID argument. // // Deprecated: Use errorutils.IsInvalidArgument() function instead. func IsInvalidArgument(err error) bool { return errorutils.IsInvalidArgument(err) } // IsInsufficientPermission checks if the given error was due to the request not having the // required authorization. This could be due to the client not having the required permission // or the specified instance ID not matching the target Firebase project. // // Deprecated: Use errorutils.IsUnauthenticated() or errorutils.IsPermissionDenied() instead. func IsInsufficientPermission(err error) bool { return errorutils.IsUnauthenticated(err) || errorutils.IsPermissionDenied(err) } // IsNotFound checks if the given error was due to a non existing instance ID. func IsNotFound(err error) bool { return errorutils.IsNotFound(err) } // IsAlreadyDeleted checks if the given error was due to the instance ID being already deleted from // the project. // // Deprecated: Use errorutils.IsConflict() function instead. 
func IsAlreadyDeleted(err error) bool { return errorutils.IsConflict(err) } // IsTooManyRequests checks if the given error was due to the client sending too many requests // causing a server quota to exceed. // // Deprecated: Use errorutils.IsResourceExhausted() function instead. func IsTooManyRequests(err error) bool { return errorutils.IsResourceExhausted(err) } // IsInternal checks if the given error was due to an internal server error. // // Deprecated: Use errorutils.IsInternal() function instead. func IsInternal(err error) bool { return errorutils.IsInternal(err) } // IsServerUnavailable checks if the given error was due to the backend server being temporarily // unavailable. // // Deprecated: Use errorutils.IsUnavailable() function instead. func IsServerUnavailable(err error) bool { return errorutils.IsUnavailable(err) } // IsUnknown checks if the given error was due to unknown error returned by the backend server. // // Deprecated: Use errorutils.IsUnknown() function instead. func IsUnknown(err error) bool { return errorutils.IsUnknown(err) } // Client is the interface for the Firebase Instance ID service. type Client struct { // To enable testing against arbitrary endpoints. endpoint string client *internal.HTTPClient project string } // NewClient creates a new instance of the Firebase instance ID Client. // // This function can only be invoked from within the SDK. Client applications should access the // the instance ID service through firebase.App. func NewClient(ctx context.Context, c *internal.InstanceIDConfig) (*Client, error) { if c.ProjectID == "" { return nil, errors.New("project id is required to access instance id client") } hc, _, err := internal.NewHTTPClient(ctx, c.Opts...) 
if err != nil { return nil, err } hc.Opts = []internal.HTTPOption{ internal.WithHeader("x-goog-api-client", internal.GetMetricsHeader(c.Version)), } hc.CreateErrFn = createError return &Client{ endpoint: iidEndpoint, client: hc, project: c.ProjectID, }, nil } // DeleteInstanceID deletes the specified instance ID and the associated data from Firebase.. // // Note that Google Analytics for Firebase uses its own form of Instance ID to keep track of // analytics data. Therefore deleting a regular instance ID does not delete Analytics data. // See https://firebase.google.com/support/privacy/manage-iids#delete_an_instance_id for // more information. func (c *Client) DeleteInstanceID(ctx context.Context, iid string) error { if iid == "" { return errors.New("instance id must not be empty") } url := fmt.Sprintf("%s/project/%s/instanceId/%s", c.endpoint, c.project, iid) _, err := c.client.Do(ctx, &internal.Request{Method: http.MethodDelete, URL: url}) return err } func createError(resp *internal.Response) error { err := internal.NewFirebaseError(resp) if msg, ok := errorMessages[resp.Status]; ok { requestPath := resp.LowLevelResponse().Request.URL.Path idx := strings.LastIndex(requestPath, "/") err.String = fmt.Sprintf("instance id %q: %s", requestPath[idx+1:], msg) } return err } golang-google-firebase-go-4.18.0/iid/iid_test.go000066400000000000000000000166441505612111400214150ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package iid import ( "context" "fmt" "net/http" "net/http/httptest" "testing" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" ) var testIIDConfig = &internal.InstanceIDConfig{ ProjectID: "test-project", Opts: []option.ClientOption{ option.WithTokenSource(&internal.MockTokenSource{AccessToken: "test-token"}), }, Version: "test-version", } func TestNoProjectID(t *testing.T) { client, err := NewClient(context.Background(), &internal.InstanceIDConfig{}) if client != nil || err == nil { t.Errorf("NewClient() = (%v, %v); want = (nil, error)", client, err) } } func TestInvalidInstanceID(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testIIDConfig) if err != nil { t.Fatal(err) } if err := client.DeleteInstanceID(ctx, ""); err == nil { t.Errorf("DeleteInstanceID(empty) = nil; want error") } } func TestDeleteInstanceID(t *testing.T) { var tr *http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testIIDConfig) if err != nil { t.Fatal(err) } client.endpoint = ts.URL if err := client.DeleteInstanceID(ctx, "test-iid"); err != nil { t.Errorf("DeleteInstanceID() = %v; want nil", err) } if tr == nil { t.Fatalf("Request = nil; want non-nil") } if tr.Method != http.MethodDelete { t.Errorf("Method = %q; want = %q", tr.Method, http.MethodDelete) } if tr.URL.Path != "/project/test-project/instanceId/test-iid" { t.Errorf("Path = %q; want = %q", tr.URL.Path, "/project/test-project/instanceId/test-iid") } if h := tr.Header.Get("Authorization"); h != "Bearer test-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer test-token") } xGoogAPIClientHeader := 
internal.GetMetricsHeader(testIIDConfig.Version) if h := tr.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader { t.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader) } } func TestDeleteInstanceIDError(t *testing.T) { status := http.StatusOK var tr *http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r w.WriteHeader(status) w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testIIDConfig) if err != nil { t.Fatal(err) } client.endpoint = ts.URL client.client.RetryConfig = nil errorHandlers := map[int]func(error) bool{ http.StatusBadRequest: errorutils.IsInvalidArgument, http.StatusUnauthorized: errorutils.IsUnauthenticated, http.StatusForbidden: errorutils.IsPermissionDenied, http.StatusNotFound: errorutils.IsNotFound, http.StatusConflict: errorutils.IsConflict, http.StatusTooManyRequests: errorutils.IsResourceExhausted, http.StatusInternalServerError: errorutils.IsInternal, http.StatusServiceUnavailable: errorutils.IsUnavailable, } deprecatedErrorHandlers := map[int]func(error) bool{ http.StatusBadRequest: IsInvalidArgument, http.StatusUnauthorized: IsInsufficientPermission, http.StatusForbidden: IsInsufficientPermission, http.StatusNotFound: IsNotFound, http.StatusConflict: IsAlreadyDeleted, http.StatusTooManyRequests: IsTooManyRequests, http.StatusInternalServerError: IsInternal, http.StatusServiceUnavailable: IsServerUnavailable, } for code, check := range errorHandlers { status = code want := fmt.Sprintf("instance id %q: %s", "test-iid", errorMessages[code]) err := client.DeleteInstanceID(ctx, "test-iid") if err == nil || !check(err) || err.Error() != want { t.Errorf("DeleteInstanceID() = %v; want = %v", err, want) } resp := errorutils.HTTPResponse(err) if resp.StatusCode != code { t.Errorf("HTTPResponse().StatusCode = %d; want = %d", resp.StatusCode, code) } deprecatedCheck 
:= deprecatedErrorHandlers[code] if !deprecatedCheck(err) { t.Errorf("DeleteInstanceID() = %v; want = %v", err, want) } if tr == nil { t.Fatalf("Request = nil; want non-nil") } if tr.Method != http.MethodDelete { t.Errorf("Method = %q; want = %q", tr.Method, http.MethodDelete) } if tr.URL.Path != "/project/test-project/instanceId/test-iid" { t.Errorf("Path = %q; want = %q", tr.URL.Path, "/project/test-project/instanceId/test-iid") } if h := tr.Header.Get("Authorization"); h != "Bearer test-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer test-token") } xGoogAPIClientHeader := internal.GetMetricsHeader(testIIDConfig.Version) if h := tr.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader { t.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader) } tr = nil } } func TestDeleteInstanceIDUnexpectedError(t *testing.T) { var tr *http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r w.WriteHeader(511) w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testIIDConfig) if err != nil { t.Fatal(err) } client.endpoint = ts.URL want := "unexpected http response with status: 511\n{}" err = client.DeleteInstanceID(ctx, "test-iid") if err == nil || err.Error() != want { t.Errorf("DeleteInstanceID() = %v; want = %v", err, want) } if !IsUnknown(err) { t.Errorf("IsUnknown() = false; want = true") } if !errorutils.IsUnknown(err) { t.Errorf("errorutils.IsUnknown() = false; want = true") } if tr == nil { t.Fatalf("Request = nil; want non-nil") } if tr.Method != http.MethodDelete { t.Errorf("Method = %q; want = %q", tr.Method, http.MethodDelete) } if tr.URL.Path != "/project/test-project/instanceId/test-iid" { t.Errorf("Path = %q; want = %q", tr.URL.Path, "/project/test-project/instanceId/test-iid") } if h := tr.Header.Get("Authorization"); h != "Bearer test-token" { t.Errorf("Authorization = %q; 
want = %q", h, "Bearer test-token") } xGoogAPIClientHeader := internal.GetMetricsHeader(testIIDConfig.Version) if h := tr.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader { t.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader) } } func TestDeleteInstanceIDConnectionError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Do nothing })) ts.Close() ctx := context.Background() client, err := NewClient(ctx, testIIDConfig) if err != nil { t.Fatal(err) } client.endpoint = ts.URL client.client.RetryConfig = nil if err := client.DeleteInstanceID(ctx, "test-iid"); err == nil { t.Fatalf("DeleteInstanceID() = nil; want = error") } } golang-google-firebase-go-4.18.0/integration/000077500000000000000000000000001505612111400210255ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/auth/000077500000000000000000000000001505612111400217665ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/auth/auth_test.go000066400000000000000000000246241505612111400243250ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package auth contains integration tests for the firebase.google.com/go/auth package. 
package auth import ( "bytes" "context" "crypto/hmac" "crypto/sha1" "crypto/sha256" "encoding/json" "flag" "fmt" "io/ioutil" "log" "math/rand" "net/http" "os" "testing" "time" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/auth" "firebase.google.com/go/v4/auth/hash" "firebase.google.com/go/v4/integration/internal" "golang.org/x/oauth2/google" "google.golang.org/api/option" ) const ( verifyCustomTokenURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken?key=%s" verifyPasswordURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyPassword?key=%s" ) var client *auth.Client var apiKey string func TestMain(m *testing.M) { flag.Parse() if testing.Short() { log.Println("skipping auth integration tests in short mode.") os.Exit(0) } app, err := internal.NewTestApp(context.Background(), nil) if err != nil { log.Fatalln(err) } client, err = app.Auth(context.Background()) if err != nil { log.Fatalln(err) } apiKey, err = internal.APIKey() if err != nil { log.Fatalln(err) } seed := time.Now().UTC().UnixNano() log.Printf("Using random seed: %d", seed) rand.Seed(seed) os.Exit(m.Run()) } func TestCustomToken(t *testing.T) { uid := randomUID() ct, err := client.CustomToken(context.Background(), uid) if err != nil { t.Fatal(err) } verifyCustomToken(t, ct, uid) } func TestCustomTokenWithoutServiceAccount(t *testing.T) { // Create a TokenSource from the service account. This makes the private key not accessible // to the Firebase APIs. 
b, err := ioutil.ReadFile(internal.Resource("integration_cert.json")) if err != nil { t.Fatal(err) } jwtConfig, err := google.JWTConfigFromJSON(b, "https://www.googleapis.com/auth/cloud-platform") if err != nil { t.Fatal(err) } appConfig := &firebase.Config{ ServiceAccountID: jwtConfig.Email, } opt := option.WithTokenSource(jwtConfig.TokenSource(context.Background())) app, err := firebase.NewApp(context.Background(), appConfig, opt) if err != nil { t.Fatal(err) } otherClient, err := app.Auth(context.Background()) if err != nil { t.Fatal(err) } uid := randomUID() ct, err := otherClient.CustomToken(context.Background(), uid) if err != nil { t.Fatal(err) } verifyCustomToken(t, ct, uid) } func TestCustomTokenWithClaims(t *testing.T) { uid := randomUID() ct, err := client.CustomTokenWithClaims(context.Background(), uid, map[string]interface{}{ "premium": true, "package": "gold", }) if err != nil { t.Fatal(err) } vt := verifyCustomToken(t, ct, uid) if premium, ok := vt.Claims["premium"].(bool); !ok || !premium { t.Errorf("Claims['premium'] = %v; want Claims['premium'] = true", vt.Claims["premium"]) } if pkg, ok := vt.Claims["package"].(string); !ok || pkg != "gold" { t.Errorf("Claims['package'] = %v; want Claims['package'] = \"gold\"", vt.Claims["package"]) } } func TestRevokeRefreshTokens(t *testing.T) { uid := "user_revoked" ct, err := client.CustomToken(context.Background(), uid) if err != nil { t.Fatal(err) } idt, err := signInWithCustomToken(ct) if err != nil { t.Fatal(err) } defer deleteUser(uid) vt, err := client.VerifyIDTokenAndCheckRevoked(context.Background(), idt) if err != nil { t.Fatal(err) } if vt.UID != uid { t.Errorf("UID = %q; want UID = %q", vt.UID, uid) } // The backend stores the validSince property in seconds since the epoch. // The issuedAt property of the token is also in seconds. 
If a token was // issued, and then in the same second tokens were revoked, the token will // have the same timestamp as the tokensValidAfterMillis, and will therefore // not be considered revoked. Hence we wait one second before revoking. time.Sleep(time.Second) if err = client.RevokeRefreshTokens(context.Background(), uid); err != nil { t.Fatal(err) } vt, err = client.VerifyIDTokenAndCheckRevoked(context.Background(), idt) we := "ID token has been revoked" if vt != nil || err == nil || err.Error() != we { t.Errorf("tok, err := VerifyIDTokenAndCheckRevoked(); got (%v, %s) ; want (%v, %v)", vt, err, nil, we) } // Does not return error for revoked token. if _, err = client.VerifyIDToken(context.Background(), idt); err != nil { t.Errorf("VerifyIDToken(); err = %s; want err = ", err) } // Sign in after revocation. if idt, err = signInWithCustomToken(ct); err != nil { t.Fatal(err) } if _, err = client.VerifyIDTokenAndCheckRevoked(context.Background(), idt); err != nil { t.Errorf("VerifyIDTokenAndCheckRevoked(); err = %s; want err = ", err) } } func TestIDTokenForDisabledUser(t *testing.T) { uid := "user_disabled" ct, err := client.CustomToken(context.Background(), uid) if err != nil { t.Fatal(err) } idt, err := signInWithCustomToken(ct) if err != nil { t.Fatal(err) } defer deleteUser(uid) vt, err := client.VerifyIDTokenAndCheckRevoked(context.Background(), idt) if err != nil { t.Fatal(err) } if vt.UID != uid { t.Errorf("UID = %q; want UID = %q", vt.UID, uid) } // Disable the user updates := auth.UserToUpdate{} updates.Disabled(true) _, err = client.UpdateUser(context.Background(), uid, &updates) if err != nil { t.Fatalf("failed to disable user with UpdateUser: %v", err) } vt, err = client.VerifyIDTokenAndCheckRevoked(context.Background(), idt) we := "user has been disabled" if vt != nil || err == nil || !auth.IsUserDisabled(err) || err.Error() != we { t.Errorf("tok, err := VerifyIDTokenAndCheckRevoked(); got (%v, %s) ; want (%v, %v)", vt, err, nil, we) } } // 
verifyCustomToken verifies the given custom token by signing into a Firebase project with it. // // A successful sign in creates the user account in the Firebase back-end. This method ensures that // such user accounts are automatically deleted upon return. func verifyCustomToken(t *testing.T, ct, uid string) *auth.Token { idt, err := signInWithCustomToken(ct) if err != nil { t.Fatal(err) } defer deleteUser(uid) vt, err := client.VerifyIDToken(context.Background(), idt) if err != nil { t.Fatal(err) } if vt.UID != uid { t.Errorf("UID = %q; want UID = %q", vt.UID, uid) } if vt.Firebase.Tenant != "" { t.Errorf("Tenant = %q; want = %q", vt.Firebase.Tenant, "") } return vt } func signInWithCustomToken(token string) (string, error) { return signInWithCustomTokenForTenant(token, "") } func signInWithCustomTokenForTenant(token string, tenantID string) (string, error) { payload := map[string]interface{}{ "token": token, "returnSecureToken": true, } if tenantID != "" { payload["tenantId"] = tenantID } req, err := json.Marshal(payload) if err != nil { return "", err } resp, err := postRequest(fmt.Sprintf(verifyCustomTokenURL, apiKey), req) if err != nil { return "", err } var respBody struct { IDToken string `json:"idToken"` } if err := json.Unmarshal(resp, &respBody); err != nil { return "", err } return respBody.IDToken, err } func signInWithPassword(email, password string) (string, error) { req, err := json.Marshal(map[string]interface{}{ "email": email, "password": password, "returnSecureToken": true, }) if err != nil { return "", err } resp, err := postRequest(fmt.Sprintf(verifyPasswordURL, apiKey), req) if err != nil { return "", err } var respBody struct { IDToken string `json:"idToken"` } if err := json.Unmarshal(resp, &respBody); err != nil { return "", err } return respBody.IDToken, err } func TestImportUserPasswordSaltOrder(t *testing.T) { const ( password = "pass123123" key = "skeleton" salt = "NaCl" ) tests := []struct { name string hashConfig auth.UserImportHash 
localHash func() []byte }{ { name: "SHA1_SaltFirst", hashConfig: hash.SHA1{ Rounds: 1, InputOrder: hash.InputOrderSaltFirst, }, localHash: func() []byte { h := sha1.New() h.Write([]byte(salt + password)) return h.Sum(nil) }, }, { name: "HMAC_SHA256_PasswordFirst", hashConfig: hash.HMACSHA256{ Key: []byte(key), InputOrder: hash.InputOrderPasswordFirst, }, localHash: func() []byte { h := hmac.New(sha256.New, []byte(key)) h.Write([]byte(password + salt)) return h.Sum(nil) }, }, } for _, test := range tests { uid := randomUID() email := randomEmail(uid) user := (&auth.UserToImport{}). UID(uid). Email(email). PasswordHash(test.localHash()). PasswordSalt([]byte(salt)) result, err := client.ImportUsers(context.Background(), []*auth.UserToImport{user}, auth.WithHash(test.hashConfig)) if err != nil { t.Fatal(err) } defer deleteUser(uid) if result.SuccessCount != 1 || result.FailureCount != 0 { t.Errorf("ImportUsers(%s) = %#v; want = {SuccessCount: 1, FailureCount: 0}", test.name, result) } savedUser, err := client.GetUser(context.Background(), uid) if err != nil { t.Fatal(err) } if savedUser.Email != email { t.Errorf("GetUser(imported) = %q; want = %q", savedUser.Email, email) } idToken, err := signInWithPassword(email, "pass123123") if err != nil { t.Errorf("Sign in failed with %+v\nError: %s", test, err) continue } if idToken == "" { t.Errorf("ID Token = empty; want = non-empty") } } } func postRequest(url string, req []byte) ([]byte, error) { resp, err := http.Post(url, "application/json", bytes.NewBuffer(req)) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("unexpected http status code: %d", resp.StatusCode) } return ioutil.ReadAll(resp.Body) } // deleteUser makes a best effort attempt to delete the given user. // // Any errors encountered during the delete are logged and ignored. 
func deleteUser(uid string) { if err := client.DeleteUser(context.Background(), uid); err != nil { log.Printf("WARN: Failed to delete user %q on tear down: %v", uid, err) } } golang-google-firebase-go-4.18.0/integration/auth/project_config_mgt_test.go000066400000000000000000000033421505612111400272200ustar00rootroot00000000000000// Copyright 2023 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "reflect" "testing" "firebase.google.com/go/v4/auth" ) func TestProjectConfig(t *testing.T) { mfaObject := &auth.MultiFactorConfig{ ProviderConfigs: []*auth.ProviderConfig{ { State: auth.Enabled, TOTPProviderConfig: &auth.TOTPProviderConfig{ AdjacentIntervals: 5, }, }, }, } want := &auth.ProjectConfig{ MultiFactorConfig: mfaObject, } t.Run("UpdateProjectConfig()", func(t *testing.T) { mfaConfigReq := *want.MultiFactorConfig req := (&auth.ProjectConfigToUpdate{}). 
MultiFactorConfig(mfaConfigReq) projectConfig, err := client.UpdateProjectConfig(context.Background(), req) if err != nil { t.Fatalf("UpdateProjectConfig() = %v", err) } if !reflect.DeepEqual(projectConfig, want) { t.Errorf("UpdateProjectConfig() = %#v; want = %#v", projectConfig, want) } }) t.Run("GetProjectConfig()", func(t *testing.T) { project, err := client.GetProjectConfig(context.Background()) if err != nil { t.Fatalf("GetProjectConfig() = %v", err) } if !reflect.DeepEqual(project, want) { t.Errorf("GetProjectConfig() = %v; want = %#v", project, want) } }) } golang-google-firebase-go-4.18.0/integration/auth/provider_config_test.go000066400000000000000000000316631505612111400265440ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package auth import ( "context" "fmt" "log" "reflect" "testing" "firebase.google.com/go/v4/auth" "google.golang.org/api/iterator" ) var x509Certs = []string{ "-----BEGIN CERTIFICATE-----\nMIICZjCCAc+gAwIBAgIBADANBgkqhkiG9w0BAQ0FADBQMQswCQYDVQQGEwJ1czEL\nMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAPBgNVBAMMCGFjbWUuY29tMRIw\nEAYDVQQHDAlTdW5ueXZhbGUwHhcNMTgxMjA2MDc1MTUxWhcNMjgxMjAzMDc1MTUx\nWjBQMQswCQYDVQQGEwJ1czELMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAP\nBgNVBAMMCGFjbWUuY29tMRIwEAYDVQQHDAlTdW5ueXZhbGUwgZ8wDQYJKoZIhvcN\nAQEBBQADgY0AMIGJAoGBAKphmggjiVgqMLXyzvI7cKphscIIQ+wcv7Dld6MD4aKv\n7Jqr8ltujMxBUeY4LFEKw8Terb01snYpDotfilaG6NxpF/GfVVmMalzwWp0mT8+H\nyzyPj89mRcozu17RwuooR6n1ofXjGcBE86lqC21UhA3WVgjPOLqB42rlE9gPnZLB\nAgMBAAGjUDBOMB0GA1UdDgQWBBS0iM7WnbCNOnieOP1HIA+Oz/ML+zAfBgNVHSME\nGDAWgBS0iM7WnbCNOnieOP1HIA+Oz/ML+zAMBgNVHRMEBTADAQH/MA0GCSqGSIb3\nDQEBDQUAA4GBAF3jBgS+wP+K/jTupEQur6iaqS4UvXd//d4vo1MV06oTLQMTz+rP\nOSMDNwxzfaOn6vgYLKP/Dcy9dSTnSzgxLAxfKvDQZA0vE3udsw0Bd245MmX4+GOp\nlbrN99XP1u+lFxCSdMUzvQ/jW4ysw/Nq4JdJ0gPAyPvL6Qi/3mQdIQwx\n-----END CERTIFICATE-----\n", "-----BEGIN CERTIFICATE-----\nMIICZjCCAc+gAwIBAgIBADANBgkqhkiG9w0BAQ0FADBQMQswCQYDVQQGEwJ1czEL\nMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAPBgNVBAMMCGFjbWUuY29tMRIw\nEAYDVQQHDAlTdW5ueXZhbGUwHhcNMTgxMjA2MDc1ODE4WhcNMjgxMjAzMDc1ODE4\nWjBQMQswCQYDVQQGEwJ1czELMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAP\nBgNVBAMMCGFjbWUuY29tMRIwEAYDVQQHDAlTdW5ueXZhbGUwgZ8wDQYJKoZIhvcN\nAQEBBQADgY0AMIGJAoGBAKuzYKfDZGA6DJgQru3wNUqv+S0hMZfP/jbp8ou/8UKu\nrNeX7cfCgt3yxoGCJYKmF6t5mvo76JY0MWwA53BxeP/oyXmJ93uHG5mFRAsVAUKs\ncVVb0Xi6ujxZGVdDWFV696L0BNOoHTfXmac6IBoZQzNNK4n1AATqwo+z7a0pfRrJ\nAgMBAAGjUDBOMB0GA1UdDgQWBBSKmi/ZKMuLN0ES7/jPa7q7jAjPiDAfBgNVHSME\nGDAWgBSKmi/ZKMuLN0ES7/jPa7q7jAjPiDAMBgNVHRMEBTADAQH/MA0GCSqGSIb3\nDQEBDQUAA4GBAAg2a2kSn05NiUOuWOHwPUjW3wQRsGxPXtbhWMhmNdCfKKteM2+/\nLd/jz5F3qkOgGQ3UDgr3SHEoWhnLaJMF4a2tm6vL2rEIfPEK81KhTTRxSsAgMVbU\nJXBz1md6Ur0HlgQC7d1CHC8/xi2DDwHopLyxhogaZUxy9IaRxUEa2vJW\n-----END CERTIFICATE-----\n", } func 
TestOIDCProviderConfig(t *testing.T) { testOIDCProviderConfig(t, client) } type oidcProviderClient interface { OIDCProviderConfig(ctx context.Context, id string) (*auth.OIDCProviderConfig, error) OIDCProviderConfigs(ctx context.Context, nextPageToken string) *auth.OIDCProviderConfigIterator CreateOIDCProviderConfig(ctx context.Context, config *auth.OIDCProviderConfigToCreate) (*auth.OIDCProviderConfig, error) UpdateOIDCProviderConfig(ctx context.Context, id string, config *auth.OIDCProviderConfigToUpdate) (*auth.OIDCProviderConfig, error) DeleteOIDCProviderConfig(ctx context.Context, id string) error } func testOIDCProviderConfig(t *testing.T, client oidcProviderClient) { id := randomOIDCProviderID() want := &auth.OIDCProviderConfig{ ID: id, DisplayName: "OIDC_DISPLAY_NAME", Enabled: true, ClientID: "OIDC_CLIENT_ID", Issuer: "https://oidc.com/issuer", IDTokenResponseType: true, } req := (&auth.OIDCProviderConfigToCreate{}). ID(id). DisplayName("OIDC_DISPLAY_NAME"). Enabled(true). ClientID("OIDC_CLIENT_ID"). 
Issuer("https://oidc.com/issuer") created, err := client.CreateOIDCProviderConfig(context.Background(), req) if err != nil { t.Fatalf("CreateOIDCProviderConfig() = %v", err) } // Clean up action in the event of a panic defer func() { if id == "" { return } if err := client.DeleteOIDCProviderConfig(context.Background(), id); err != nil { log.Printf("WARN: failed to delete OIDC provider config %q on tear down: %v", id, err) } }() t.Run("CreateOIDCProviderConfig()", func(t *testing.T) { if !reflect.DeepEqual(created, want) { t.Errorf("CreateOIDCProviderConfig() = %#v; want = %#v", created, want) } }) t.Run("OIDCProviderConfig()", func(t *testing.T) { oidc, err := client.OIDCProviderConfig(context.Background(), id) if err != nil { t.Fatalf("OIDCProviderConfig() = %v", err) } if !reflect.DeepEqual(oidc, want) { t.Errorf("OIDCProviderConfig() = %#v; want = %#v", oidc, want) } }) t.Run("OIDCProviderConfigs()", func(t *testing.T) { iter := client.OIDCProviderConfigs(context.Background(), "") var target *auth.OIDCProviderConfig for { oidc, err := iter.Next() if err == iterator.Done { break } else if err != nil { t.Fatalf("OIDCProviderConfigs() = %v", err) } if oidc.ID == id { target = oidc break } } if target == nil { t.Fatalf("OIDCProviderConfigs() did not return required config: %q", id) } if !reflect.DeepEqual(target, want) { t.Errorf("OIDCProviderConfigs() = %#v; want = %#v", target, want) } }) t.Run("UpdateOIDCProviderConfig()", func(t *testing.T) { want = &auth.OIDCProviderConfig{ ID: id, DisplayName: "UPDATED_OIDC_DISPLAY_NAME", ClientID: "UPDATED_OIDC_CLIENT_ID", Issuer: "https://oidc.com/updated_issuer", IDTokenResponseType: true, } req := (&auth.OIDCProviderConfigToUpdate{}). DisplayName("UPDATED_OIDC_DISPLAY_NAME"). Enabled(false). ClientID("UPDATED_OIDC_CLIENT_ID"). 
Issuer("https://oidc.com/updated_issuer") oidc, err := client.UpdateOIDCProviderConfig(context.Background(), id, req) if err != nil { t.Fatalf("UpdateOIDCProviderConfig() = %v", err) } if !reflect.DeepEqual(oidc, want) { t.Errorf("UpdateOIDCProviderConfig() = %#v; want = %#v", oidc, want) } }) t.Run("UpdateOIDCProviderConfig() should be rejected with invalid oauth response type", func(t *testing.T) { req := (&auth.OIDCProviderConfigToUpdate{}). DisplayName("UPDATED_OIDC_DISPLAY_NAME"). Enabled(false). ClientID("UPDATED_OIDC_CLIENT_ID"). Issuer("https://oidc.com/updated_issuer"). IDTokenResponseType(false). CodeResponseType(false). ClientSecret("CLIENT_SECRET") _, err := client.UpdateOIDCProviderConfig(context.Background(), id, req) if err == nil { t.Fatalf("UpdateOIDCProviderConfig(invalid_oauth_response_type) error nil; want not nil") } if err.Error() != "At least one response type must be returned" { t.Errorf( "UpdateOIDCProviderConfig(invalid_oauth_response_type) returned an error of '%s'; "+ "expected 'At least one response type must be returned'", err.Error()) } }) t.Run("UpdateOIDCProviderConfig() should be rejected code flow with no client secret", func(t *testing.T) { req := (&auth.OIDCProviderConfigToUpdate{}). DisplayName("UPDATED_OIDC_DISPLAY_NAME"). Enabled(false). ClientID("UPDATED_OIDC_CLIENT_ID"). Issuer("https://oidc.com/updated_issuer"). IDTokenResponseType(false). 
CodeResponseType(true) _, err := client.UpdateOIDCProviderConfig(context.Background(), id, req) if err == nil { t.Fatalf("UpdateOIDCProviderConfig(code_flow_with_no_client_secret) error nil; want not nil") } if err.Error() != "Client Secret must not be empty for Code Response Type" { t.Errorf( "UpdateOIDCProviderConfig(code_flow_with_no_client_secret) returned an error of '%s'; "+ "expected 'Client Secret must not be empty for Code Response Type'", err.Error()) } }) t.Run("DeleteOIDCProviderConfig()", func(t *testing.T) { if err := client.DeleteOIDCProviderConfig(context.Background(), id); err != nil { t.Fatalf("DeleteOIDCProviderConfig() = %v", err) } _, err := client.OIDCProviderConfig(context.Background(), id) if err == nil || !auth.IsConfigurationNotFound(err) { t.Errorf("OIDCProviderConfig() = %v; want = ConfigurationNotFound", err) } id = "" }) } func TestSAMLProviderConfig(t *testing.T) { testSAMLProviderConfig(t, client) } type samlProviderClient interface { SAMLProviderConfig(ctx context.Context, id string) (*auth.SAMLProviderConfig, error) SAMLProviderConfigs(ctx context.Context, nextPageToken string) *auth.SAMLProviderConfigIterator CreateSAMLProviderConfig(ctx context.Context, config *auth.SAMLProviderConfigToCreate) (*auth.SAMLProviderConfig, error) UpdateSAMLProviderConfig(ctx context.Context, id string, config *auth.SAMLProviderConfigToUpdate) (*auth.SAMLProviderConfig, error) DeleteSAMLProviderConfig(ctx context.Context, id string) error } func testSAMLProviderConfig(t *testing.T, client samlProviderClient) { id := randomSAMLProviderID() want := &auth.SAMLProviderConfig{ ID: id, DisplayName: "SAML_DISPLAY_NAME", Enabled: true, IDPEntityID: "IDP_ENTITY_ID", SSOURL: "https://example.com/login", X509Certificates: []string{ x509Certs[0], }, RPEntityID: "RP_ENTITY_ID", CallbackURL: "https://projectId.firebaseapp.com/__/auth/handler", RequestSigningEnabled: true, } req := (&auth.SAMLProviderConfigToCreate{}). ID(id). DisplayName("SAML_DISPLAY_NAME"). 
Enabled(true). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/login"). X509Certificates([]string{x509Certs[0]}). RPEntityID("RP_ENTITY_ID"). CallbackURL("https://projectId.firebaseapp.com/__/auth/handler"). RequestSigningEnabled(true) created, err := client.CreateSAMLProviderConfig(context.Background(), req) if err != nil { t.Fatalf("CreateSAMLProviderConfig() = %v", err) } // Clean up action in the event of a panic defer func() { if id == "" { return } if err := client.DeleteSAMLProviderConfig(context.Background(), id); err != nil { log.Printf("WARN: failed to delete SAML provider config %q on tear down: %v", id, err) } }() t.Run("CreateSAMLProviderConfig()", func(t *testing.T) { if !reflect.DeepEqual(created, want) { t.Errorf("CreateSAMLProviderConfig() = %#v; want = %#v", created, want) } }) t.Run("SAMLProviderConfig()", func(t *testing.T) { saml, err := client.SAMLProviderConfig(context.Background(), id) if err != nil { t.Fatalf("SAMLProviderConfig() = %v", err) } if !reflect.DeepEqual(saml, want) { t.Errorf("SAMLProviderConfig() = %#v; want = %#v", saml, want) } }) t.Run("SAMLProviderConfigs()", func(t *testing.T) { iter := client.SAMLProviderConfigs(context.Background(), "") var target *auth.SAMLProviderConfig for { saml, err := iter.Next() if err == iterator.Done { break } else if err != nil { t.Fatalf("SAMLProviderConfigs() = %v", err) } if saml.ID == id { target = saml break } } if target == nil { t.Fatalf("SAMLProviderConfigs() did not return required config: %q", id) } if !reflect.DeepEqual(target, want) { t.Errorf("SAMLProviderConfigs() = %#v; want = %#v", target, want) } }) t.Run("UpdateSAMLProviderConfig()", func(t *testing.T) { want = &auth.SAMLProviderConfig{ ID: id, DisplayName: "UPDATED_SAML_DISPLAY_NAME", IDPEntityID: "UPDATED_IDP_ENTITY_ID", SSOURL: "https://example.com/updated_login", X509Certificates: []string{ x509Certs[1], }, RPEntityID: "UPDATED_RP_ENTITY_ID", CallbackURL: 
"https://updatedProjectId.firebaseapp.com/__/auth/handler", } req := (&auth.SAMLProviderConfigToUpdate{}). DisplayName("UPDATED_SAML_DISPLAY_NAME"). Enabled(false). IDPEntityID("UPDATED_IDP_ENTITY_ID"). SSOURL("https://example.com/updated_login"). X509Certificates([]string{x509Certs[1]}). RPEntityID("UPDATED_RP_ENTITY_ID"). CallbackURL("https://updatedProjectId.firebaseapp.com/__/auth/handler"). RequestSigningEnabled(false) saml, err := client.UpdateSAMLProviderConfig(context.Background(), id, req) if err != nil { t.Fatalf("UpdateSAMLProviderConfig() = %v", err) } if !reflect.DeepEqual(saml, want) { t.Errorf("UpdateSAMLProviderConfig() = %#v; want = %#v", saml, want) } }) t.Run("DeleteSAMLProviderConfig()", func(t *testing.T) { if err := client.DeleteSAMLProviderConfig(context.Background(), id); err != nil { t.Fatalf("DeleteSAMLProviderConfig() = %v", err) } _, err := client.SAMLProviderConfig(context.Background(), id) if err == nil || !auth.IsConfigurationNotFound(err) { t.Errorf("SAMLProviderConfig() = %v; want = ConfigurationNotFound", err) } id = "" }) } func randomSAMLProviderID() string { return fmt.Sprintf("saml.%s", randomCharacterString()) } func randomOIDCProviderID() string { return fmt.Sprintf("oidc.%s", randomCharacterString()) } func randomCharacterString() string { var letters = []rune("abcdefghijklmnopqrstuvwxyz") b := make([]rune, 10) for i := range b { b[i] = letters[seededRand.Intn(len(letters))] } return string(b) } func deleteSAMLProviderConfig(id string) { if err := client.DeleteSAMLProviderConfig(context.Background(), id); err != nil { log.Printf("WARN: failed to delete SAML provider config %q on tear down: %v", id, err) } } golang-google-firebase-go-4.18.0/integration/auth/tenant_mgt_test.go000066400000000000000000000265541505612111400255300ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "context" "log" "net/url" "reflect" "testing" "time" "firebase.google.com/go/v4/auth" "google.golang.org/api/iterator" ) func TestTenantManager(t *testing.T) { mfaObject := &auth.MultiFactorConfig{ ProviderConfigs: []*auth.ProviderConfig{ { State: auth.Enabled, TOTPProviderConfig: &auth.TOTPProviderConfig{ AdjacentIntervals: 5, }, }, }, } want := &auth.Tenant{ DisplayName: "admin-go-tenant", AllowPasswordSignUp: true, EnableEmailLinkSignIn: true, EnableAnonymousUsers: true, MultiFactorConfig: mfaObject, } req := (&auth.TenantToCreate{}). DisplayName("admin-go-tenant"). AllowPasswordSignUp(true). EnableEmailLinkSignIn(true). EnableAnonymousUsers(true). 
MultiFactorConfig(*mfaObject) created, err := client.TenantManager.CreateTenant(context.Background(), req) if err != nil { t.Fatalf("CreateTenant() = %v", err) } id := created.ID want.ID = id // Clean up action in the event of a panic defer func() { if id == "" { return } if err := client.TenantManager.DeleteTenant(context.Background(), id); err != nil { log.Printf("WARN: failed to delete tenant %q on tear down: %v", id, err) } }() t.Run("CreateTenant()", func(t *testing.T) { if !reflect.DeepEqual(created, want) { t.Errorf("CreateTenant() = %#v; want = %#v", created, want) } }) t.Run("Tenant()", func(t *testing.T) { tenant, err := client.TenantManager.Tenant(context.Background(), id) if err != nil { t.Fatalf("Tenant() = %v", err) } if !reflect.DeepEqual(tenant, want) { t.Errorf("Tenant() = %#v; want = %#v", tenant, want) } }) t.Run("Tenants()", func(t *testing.T) { iter := client.TenantManager.Tenants(context.Background(), "") var target *auth.Tenant for { tenant, err := iter.Next() if err == iterator.Done { break } else if err != nil { t.Fatalf("Tenants() = %v", err) } if tenant.ID == id { target = tenant break } } if target == nil { t.Fatalf("Tenants() did not return required tenant: %q", id) } if !reflect.DeepEqual(target, want) { t.Errorf("Tenants() = %#v; want = %#v", target, want) } }) t.Run("CustomTokens", func(t *testing.T) { testTenantAwareCustomToken(t, id) }) t.Run("UserManagement", func(t *testing.T) { testTenantAwareUserManagement(t, id) }) t.Run("OIDCProviderConfig", func(t *testing.T) { tenantClient, err := client.TenantManager.AuthForTenant(id) if err != nil { t.Fatalf("AuthForTenant() = %v", err) } testOIDCProviderConfig(t, tenantClient) }) t.Run("SAMLProviderConfig", func(t *testing.T) { tenantClient, err := client.TenantManager.AuthForTenant(id) if err != nil { t.Fatalf("AuthForTenant() = %v", err) } testSAMLProviderConfig(t, tenantClient) }) t.Run("UpdateTenant()", func(t *testing.T) { mfaObject := &auth.MultiFactorConfig{ ProviderConfigs: 
[]*auth.ProviderConfig{ { State: auth.Enabled, TOTPProviderConfig: &auth.TOTPProviderConfig{ AdjacentIntervals: 5, }, }, }, } want = &auth.Tenant{ ID: id, DisplayName: "updated-go-tenant", AllowPasswordSignUp: false, EnableEmailLinkSignIn: false, EnableAnonymousUsers: false, MultiFactorConfig: mfaObject, } req := (&auth.TenantToUpdate{}). DisplayName("updated-go-tenant"). AllowPasswordSignUp(false). EnableEmailLinkSignIn(false). EnableAnonymousUsers(false). MultiFactorConfig(*mfaObject) tenant, err := client.TenantManager.UpdateTenant(context.Background(), id, req) if err != nil { t.Fatalf("UpdateTenant() = %v", err) } if !reflect.DeepEqual(tenant, want) { t.Errorf("UpdateTenant() = %#v; want = %#v", tenant, want) } }) t.Run("DeleteTenant()", func(t *testing.T) { if err := client.TenantManager.DeleteTenant(context.Background(), id); err != nil { t.Fatalf("DeleteTenant() = %v", err) } _, err := client.TenantManager.Tenant(context.Background(), id) if err == nil || !auth.IsTenantNotFound(err) { t.Errorf("Tenant() = %v; want = TenantNotFound", err) } id = "" }) } func testTenantAwareCustomToken(t *testing.T, id string) { tenantClient, err := client.TenantManager.AuthForTenant(id) if err != nil { t.Fatalf("AuthForTenant() = %v", err) } uid := randomUID() ct, err := tenantClient.CustomToken(context.Background(), uid) if err != nil { t.Fatal(err) } idToken, err := signInWithCustomTokenForTenant(ct, id) if err != nil { t.Fatal(err) } defer func() { tenantClient.DeleteUser(context.Background(), uid) }() vt, err := tenantClient.VerifyIDToken(context.Background(), idToken) if err != nil { t.Fatal(err) } if vt.UID != uid { t.Errorf("UID = %q; want UID = %q", vt.UID, uid) } if vt.Firebase.Tenant != id { t.Errorf("Tenant = %q; want = %q", vt.Firebase.Tenant, id) } } func testTenantAwareUserManagement(t *testing.T, id string) { tenantClient, err := client.TenantManager.AuthForTenant(id) if err != nil { t.Fatalf("AuthForTenant() = %v", err) } user, err := 
tenantClient.CreateUser(context.Background(), nil) if err != nil { t.Fatalf("CreateUser() = %v", err) } t.Run("CreateUser()", func(t *testing.T) { if user.TenantID != id { t.Errorf("CreateUser().TenantID = %q; want = %q", user.TenantID, id) } }) want := auth.UserInfo{ UID: user.UID, Email: randomEmail(user.UID), PhoneNumber: randomPhoneNumber(), ProviderID: "firebase", } t.Run("UpdateUser()", func(t *testing.T) { req := (&auth.UserToUpdate{}). Email(want.Email). PhoneNumber(want.PhoneNumber) updated, err := tenantClient.UpdateUser(context.Background(), user.UID, req) if err != nil { t.Fatalf("UpdateUser() = %v", err) } if updated.TenantID != id { t.Errorf("UpdateUser().TenantID = %q; want = %q", updated.TenantID, id) } if !reflect.DeepEqual(*updated.UserInfo, want) { t.Errorf("UpdateUser() = %v; want = %v", *updated.UserInfo, want) } }) t.Run("GetUser()", func(t *testing.T) { got, err := tenantClient.GetUser(context.Background(), user.UID) if err != nil { t.Fatalf("GetUser() = %v", err) } if got.TenantID != id { t.Errorf("GetUser().TenantID = %q; want = %q", got.TenantID, id) } if !reflect.DeepEqual(*got.UserInfo, want) { t.Errorf("GetUser() = %v; want = %v", *got.UserInfo, want) } }) t.Run("Users()", func(t *testing.T) { iter := tenantClient.Users(context.Background(), "") var target *auth.ExportedUserRecord for { got, err := iter.Next() if err == iterator.Done { break } else if err != nil { t.Fatalf("Users() = %v", err) } if got.UID == user.UID { target = got break } } if target == nil { t.Fatalf("Users() did not return required user: %q", user.UID) } if !reflect.DeepEqual(*target.UserInfo, want) { t.Errorf("Users() = %v; want = %v", *target.UserInfo, want) } }) t.Run("SetCustomUserClaims()", func(t *testing.T) { claims := map[string]interface{}{ "premium": true, "role": "customer", } if err := tenantClient.SetCustomUserClaims(context.Background(), user.UID, claims); err != nil { t.Fatalf("SetCustomUserClaims() = %v", err) } got, err := 
tenantClient.GetUser(context.Background(), user.UID) if err != nil { t.Fatalf("GetUser() = %v", err) } if !reflect.DeepEqual(got.CustomClaims, claims) { t.Errorf("CustomClaims = %v; want = %v", got.CustomClaims, claims) } }) t.Run("EmailVerificationLink()", func(t *testing.T) { link, err := tenantClient.EmailVerificationLink(context.Background(), want.Email) if err != nil { t.Fatalf("EmailVerificationLink() = %v", err) } tenant, err := extractTenantID(link) if err != nil { t.Fatalf("EmailVerificationLink() = %v", err) } if id != tenant { t.Fatalf("EmailVerificationLink() TenantID = %q; want = %q", tenant, id) } }) t.Run("PasswordResetLink()", func(t *testing.T) { link, err := tenantClient.PasswordResetLink(context.Background(), want.Email) if err != nil { t.Fatalf("PasswordResetLink() = %v", err) } tenant, err := extractTenantID(link) if err != nil { t.Fatalf("PasswordResetLink() = %v", err) } if id != tenant { t.Fatalf("PasswordResetLink() TenantID = %q; want = %q", tenant, id) } }) t.Run("EmailSignInLink()", func(t *testing.T) { link, err := tenantClient.EmailSignInLink(context.Background(), want.Email, &auth.ActionCodeSettings{ URL: continueURL, HandleCodeInApp: false, }) if err != nil { t.Fatalf("EmailSignInLink() = %v", err) } tenant, err := extractTenantID(link) if err != nil { t.Fatalf("EmailSignInLink() = %v", err) } if id != tenant { t.Fatalf("EmailSignInLink() TenantID = %q; want = %q", tenant, id) } }) t.Run("RevokeRefreshTokens()", func(t *testing.T) { validSinceMillis := time.Now().Unix() * 1000 time.Sleep(1 * time.Second) if err := tenantClient.RevokeRefreshTokens(context.Background(), user.UID); err != nil { t.Fatalf("RevokeRefreshTokens() = %v", err) } got, err := tenantClient.GetUser(context.Background(), user.UID) if err != nil { t.Fatalf("GetUser() = %v", err) } if got.TokensValidAfterMillis < validSinceMillis { t.Fatalf("RevokeRefreshTokens() TokensValidAfterMillis (%d) < Now (%d)", got.TokensValidAfterMillis, validSinceMillis) } }) 
t.Run("ImportUsers()", func(t *testing.T) { scrypt, passwordHash, err := newScryptHash() if err != nil { t.Fatalf("newScryptHash() = %v", err) } uid := randomUID() email := randomEmail(uid) user := (&auth.UserToImport{}). UID(uid). Email(email). PasswordHash(passwordHash). PasswordSalt([]byte("NaCl")) result, err := tenantClient.ImportUsers(context.Background(), []*auth.UserToImport{user}, auth.WithHash(scrypt)) if err != nil { t.Fatalf("ImportUsers() = %v", err) } defer func() { tenantClient.DeleteUser(context.Background(), uid) }() if result.SuccessCount != 1 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 1, FailureCount: 0}", result) } savedUser, err := tenantClient.GetUser(context.Background(), uid) if err != nil { t.Fatalf("GetUser() = %v", err) } if savedUser.Email != email { t.Errorf("ImportUser() Email = %q; want = %q", savedUser.Email, email) } if savedUser.TenantID != id { t.Errorf("ImportUser() TenantID = %q; want = %q", savedUser.TenantID, id) } }) t.Run("DeleteUser()", func(t *testing.T) { if err := tenantClient.DeleteUser(context.Background(), user.UID); err != nil { t.Fatalf("DeleteUser() = %v", err) } _, err = tenantClient.GetUser(context.Background(), user.UID) if err == nil || !auth.IsUserNotFound(err) { t.Errorf("Tenant() = %v; want = UserNotFound", err) } }) } func extractTenantID(actionLink string) (string, error) { u, err := url.Parse(actionLink) if err != nil { return "", err } q := u.Query() return q.Get("tenantId"), nil } golang-google-firebase-go-4.18.0/integration/auth/user_mgt_test.go000066400000000000000000001247041505612111400252110ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package auth contains integration tests for the firebase.google.com/go/auth package. package auth import ( "context" "encoding/base64" "encoding/json" "fmt" "math/rand" "net/url" "reflect" "sort" "strings" "testing" "time" "firebase.google.com/go/v4/auth" "firebase.google.com/go/v4/auth/hash" "google.golang.org/api/iterator" ) const ( continueURL = "http://localhost/?a=1&b=2#c=3" continueURLKey = "continueUrl" oobCodeKey = "oobCode" modeKey = "mode" resetPasswordURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/resetPassword?key=%s" emailLinkSignInURL = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/emailLinkSignin?key=%s" ) func TestGetUser(t *testing.T) { want := newUserWithParams(t) defer deleteUser(want.UID) cases := []struct { name string getOp func(context.Context) (*auth.UserRecord, error) }{ { "GetUser()", func(ctx context.Context) (*auth.UserRecord, error) { return client.GetUser(ctx, want.UID) }, }, { "GetUserByEmail()", func(ctx context.Context) (*auth.UserRecord, error) { return client.GetUserByEmail(ctx, want.Email) }, }, { "GetUserByPhoneNumber()", func(ctx context.Context) (*auth.UserRecord, error) { return client.GetUserByPhoneNumber(ctx, want.PhoneNumber) }, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { got, err := tc.getOp(context.Background()) if err != nil || !reflect.DeepEqual(*got, *want) { t.Errorf("%s = (%#v, %v); want = (%#v, nil)", tc.name, got, err, want) } }) } } func TestGetUserByProviderUID(t *testing.T) { // TODO(rsgowman): Once we can link a provider id with a user, just 
do that // here instead of importing a new user. importUserUID := randomUID() providerUID := "google_" + importUserUID userToImport := (&auth.UserToImport{}). UID(importUserUID). Email(randomEmail(importUserUID)). PhoneNumber(randomPhoneNumber()). ProviderData([](*auth.UserProvider){ &auth.UserProvider{ ProviderID: "google.com", UID: providerUID, }, }) importUser(t, importUserUID, userToImport) defer deleteUser(importUserUID) userRecord, err := client.GetUserByProviderUID(context.Background(), "google.com", providerUID) if err != nil { t.Fatalf("GetUserByProviderUID() = %q", err) } if userRecord.UID != importUserUID { t.Errorf("GetUserByProviderUID().UID = %v; want = %v", userRecord.UID, importUserUID) } } func TestGetNonExistingUser(t *testing.T) { user, err := client.GetUser(context.Background(), "non.existing") if user != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) } user, err = client.GetUserByEmail(context.Background(), "non.existing@definitely.non.existing") if user != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUserByEmail(non.existing) = (%v, %v); want = (nil, error)", user, err) } user, err = client.GetUserByPhoneNumber(context.Background(), "+14044040404") if user != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) } user, err = client.GetUserByProviderUID(context.Background(), "google.com", "a-uid-that-doesnt-exist") if user != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUser(non.existing) = (%v, %v); want = (nil, error)", user, err) } } func TestGetUsers(t *testing.T) { // Checks to see if the users list contain the given uids. Order is ignored. // // Behaviour is undefined if there are duplicate entries in either of the // slices. 
// // This function is identical to the one in auth/user_mgt_test.go sameUsers := func(users [](*auth.UserRecord), uids []string) bool { if len(users) != len(uids) { return false } sort.Slice(users, func(i, j int) bool { return users[i].UID < users[j].UID }) sort.Slice(uids, func(i, j int) bool { return uids[i] < uids[j] }) for i := range users { if users[i].UID != uids[i] { return false } } return true } testUser1 := newUserWithParams(t) defer deleteUser(testUser1.UID) testUser2 := newUserWithParams(t) defer deleteUser(testUser2.UID) testUser3 := newUserWithParams(t) defer deleteUser(testUser3.UID) importUser1UID := randomUID() importUser1 := (&auth.UserToImport{}). UID(importUser1UID). Email(randomEmail(importUser1UID)). PhoneNumber(randomPhoneNumber()). ProviderData([](*auth.UserProvider){ &auth.UserProvider{ ProviderID: "google.com", UID: "google_" + importUser1UID, }, }) importUser(t, importUser1UID, importUser1) defer deleteUser(importUser1UID) userRecordsToUIDs := func(users [](*auth.UserRecord)) []string { results := []string{} for i := range users { results = append(results, users[i].UID) } return results } t.Run("various identifier types", func(t *testing.T) { getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ auth.UIDIdentifier{UID: testUser1.UID}, auth.EmailIdentifier{Email: testUser2.Email}, auth.PhoneIdentifier{PhoneNumber: testUser3.PhoneNumber}, auth.ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_" + importUser1UID}, }) if err != nil { t.Fatalf("GetUsers() = %q", err) } if !sameUsers(getUsersResult.Users, []string{testUser1.UID, testUser2.UID, testUser3.UID, importUser1UID}) { t.Errorf("GetUsers() = %v; want = %v (in any order)", userRecordsToUIDs(getUsersResult.Users), []string{testUser1.UID, testUser2.UID, testUser3.UID, importUser1UID}) } }) t.Run("mix of existing and non-existing users", func(t *testing.T) { getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ 
auth.UIDIdentifier{UID: testUser1.UID}, auth.UIDIdentifier{UID: "uid_that_doesnt_exist"}, auth.UIDIdentifier{UID: testUser3.UID}, }) if err != nil { t.Fatalf("GetUsers() = %q", err) } if !sameUsers(getUsersResult.Users, []string{testUser1.UID, testUser3.UID}) { t.Errorf("GetUsers() = %v; want = %v (in any order)", getUsersResult.Users, []string{testUser1.UID, testUser3.UID}) } if len(getUsersResult.NotFound) != 1 { t.Errorf("len(GetUsers().NotFound) = %d; want 1", len(getUsersResult.NotFound)) } else { if getUsersResult.NotFound[0].(auth.UIDIdentifier).UID != "uid_that_doesnt_exist" { t.Errorf("GetUsers().NotFound[0].UID = %s; want 'uid_that_doesnt_exist'", getUsersResult.NotFound[0].(auth.UIDIdentifier).UID) } } }) t.Run("only non-existing users", func(t *testing.T) { getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ auth.UIDIdentifier{UID: "non-existing user"}, }) if err != nil { t.Fatalf("GetUsers() = %q", err) } if len(getUsersResult.Users) != 0 { t.Errorf("len(GetUsers().Users) = %d; want = 0", len(getUsersResult.Users)) } if len(getUsersResult.NotFound) != 1 { t.Errorf("len(GetUsers().NotFound) = %d; want = 1", len(getUsersResult.NotFound)) } else { if getUsersResult.NotFound[0].(auth.UIDIdentifier).UID != "non-existing user" { t.Errorf("GetUsers().NotFound[0].UID = %s; want 'non-existing user'", getUsersResult.NotFound[0].(auth.UIDIdentifier).UID) } } }) t.Run("de-dups duplicate users", func(t *testing.T) { getUsersResult, err := client.GetUsers(context.Background(), []auth.UserIdentifier{ auth.UIDIdentifier{UID: testUser1.UID}, auth.UIDIdentifier{UID: testUser1.UID}, }) if err != nil { t.Fatalf("GetUsers() returned an error: %v", err) } if len(getUsersResult.Users) != 1 { t.Errorf("len(GetUsers().Users) = %d; want = 1", len(getUsersResult.Users)) } else { if getUsersResult.Users[0].UID != testUser1.UID { t.Errorf("GetUsers().Users[0].UID = %s; want = '%s'", getUsersResult.Users[0].UID, testUser1.UID) } } if 
len(getUsersResult.NotFound) != 0 {
			t.Errorf("len(GetUsers().NotFound) = %d; want = 0", len(getUsersResult.NotFound))
		}
	})
}

// TestLastRefreshTime creates a user, signs in with its password, and then
// checks that the LastRefreshTimestamp reported by GetUser lies within ten
// minutes of the local clock.
func TestLastRefreshTime(t *testing.T) {
	userRecord := newUserWithParams(t)
	defer deleteUser(userRecord.UID)

	// New users should not have a LastRefreshTimestamp set.
	if userRecord.UserMetadata.LastRefreshTimestamp != 0 {
		t.Errorf(
			"CreateUser(...).UserMetadata.LastRefreshTimestamp = %d; want = 0",
			userRecord.UserMetadata.LastRefreshTimestamp)
	}

	// Login to cause the LastRefreshTimestamp to be set
	if _, err := signInWithPassword(userRecord.Email, "password"); err != nil {
		t.Errorf("signInWithPassword failed: %v", err)
	}

	// Attempt to retrieve the user 3 times (with a small delay between each attempt.) Occasionally,
	// this call retrieves the user data without the lastLoginTime/lastRefreshTime fields; possibly
	// because it's hitting a different server than what the login request used.
	const maxAttempts = 3
	var getUsersResult *auth.UserRecord
	for i := 0; i < maxAttempts; i++ {
		var err error
		getUsersResult, err = client.GetUser(context.Background(), userRecord.UID)
		if err != nil {
			t.Fatalf("GetUser(...) failed with error: %v", err)
		}

		if getUsersResult.UserMetadata.LastRefreshTimestamp != 0 {
			break
		}

		// Exponential backoff: 1s, then 2s. In Go `^` is bitwise XOR, not
		// exponentiation, so the delay is computed with a shift. There is no
		// point in sleeping once the final attempt has been made.
		if i < maxAttempts-1 {
			time.Sleep(time.Second << uint(i))
		}
	}

	// Ensure last refresh time is approx now (with tolerance of 10m)
	nowMillis := time.Now().Unix() * 1000
	lastRefreshTimestamp := getUsersResult.UserMetadata.LastRefreshTimestamp
	if lastRefreshTimestamp < nowMillis-10*60*1000 {
		t.Errorf("GetUser(...).UserMetadata.LastRefreshTimestamp = %d; want >= %d",
			lastRefreshTimestamp, nowMillis-10*60*1000)
	}
	if nowMillis+10*60*1000 < lastRefreshTimestamp {
		t.Errorf("GetUser(...).UserMetadata.LastRefreshTimestamp = %d; want <= %d",
			lastRefreshTimestamp, nowMillis+10*60*1000)
	}
}

// TestUpdateNonExistingUser checks that updating an unknown UID returns a nil
// user and an error for which auth.IsUserNotFound reports true.
func TestUpdateNonExistingUser(t *testing.T) {
	update := (&auth.UserToUpdate{}).Email("test@example.com")
	user, err := client.UpdateUser(context.Background(), "non.existing", update)
	if user != nil || !auth.IsUserNotFound(err) {
		t.Errorf("UpdateUser(non.existing) = (%v, %v); want = (nil, error)", user, err)
	}
}

// TestDeleteNonExistingUser checks that deleting an unknown UID returns an
// error for which auth.IsUserNotFound reports true.
func TestDeleteNonExistingUser(t *testing.T) {
	err := client.DeleteUser(context.Background(), "non.existing")
	if !auth.IsUserNotFound(err) {
		t.Errorf("DeleteUser(non.existing) = %v; want = error", err)
	}
}

func TestListUsers(t *testing.T) {
	errMsgTemplate := "Users() %s = empty; want = non-empty. A common cause would be " +
		"forgetting to add the 'Firebase Authentication Admin' permission. 
See " + "instructions in CONTRIBUTING.md" newUsers := map[string]bool{} user := newUserWithParams(t) defer deleteUser(user.UID) newUsers[user.UID] = true user = newUserWithParams(t) defer deleteUser(user.UID) newUsers[user.UID] = true user = newUserWithParams(t) defer deleteUser(user.UID) newUsers[user.UID] = true // test regular iteration count := 0 iter := client.Users(context.Background(), "") for { u, err := iter.Next() if err == iterator.Done { break } else if err != nil { t.Fatal(err) } if _, ok := newUsers[u.UID]; ok { count++ if u.PasswordHash == "" { t.Errorf(errMsgTemplate, "PasswordHash") } if u.PasswordSalt == "" { t.Errorf(errMsgTemplate, "PasswordSalt") } } } if count < 3 { t.Errorf("Users() count = %d; want >= 3", count) } // test paged iteration count = 0 pageCount := 0 iter = client.Users(context.Background(), "") pager := iterator.NewPager(iter, 2, "") for { pageCount++ var users []*auth.ExportedUserRecord nextPageToken, err := pager.NextPage(&users) if err != nil { t.Fatal(err) } count += len(users) if nextPageToken == "" { break } } if count < 3 { t.Errorf("Users() count = %d; want >= 3", count) } if pageCount < 2 { t.Errorf("NewPager() pages = %d; want >= 2", pageCount) } } func TestCreateUser(t *testing.T) { user, err := client.CreateUser(context.Background(), nil) if err != nil { t.Fatalf("CreateUser() = %v; want = nil", err) } defer deleteUser(user.UID) var emptyFactors []*auth.MultiFactorInfo want := auth.UserRecord{ UserInfo: &auth.UserInfo{ UID: user.UID, ProviderID: "firebase", }, UserMetadata: &auth.UserMetadata{ CreationTimestamp: user.UserMetadata.CreationTimestamp, }, TokensValidAfterMillis: user.TokensValidAfterMillis, MultiFactor: &auth.MultiFactorSettings{ EnrolledFactors: emptyFactors, }, } if !reflect.DeepEqual(*user, want) { t.Errorf("CreateUser() = %#v; want = %#v", *user, want) } user, err = client.CreateUser(context.Background(), (&auth.UserToCreate{}).UID(user.UID)) if err == nil || user != nil || 
!auth.IsUIDAlreadyExists(err) { t.Errorf("CreateUser(existing-uid) = (%#v, %v); want = (nil, error)", user, err) } } func TestCreateUserMFA(t *testing.T) { var tc *auth.UserToCreate = &auth.UserToCreate{} tc.Email("testuser@example.com") tc.EmailVerified(true) tc.MFASettings(auth.MultiFactorSettings{ EnrolledFactors: []*auth.MultiFactorInfo{ { PhoneNumber: "+11234567890", DisplayName: "Phone Number deprecated", FactorID: "phone", }, { Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: "+19876543210", }, DisplayName: "Phone Number active", FactorID: "phone", }, }, }) user, err := client.CreateUser(context.Background(), tc) if err != nil { t.Fatalf("CreateUser() = %v; want = nil", err) } defer deleteUser(user.UID) var factors []*auth.MultiFactorInfo = []*auth.MultiFactorInfo{ { UID: user.MultiFactor.EnrolledFactors[0].UID, DisplayName: "Phone Number deprecated", FactorID: "phone", Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: "+11234567890", }, PhoneNumber: "+11234567890", EnrollmentTimestamp: user.MultiFactor.EnrolledFactors[0].EnrollmentTimestamp, }, { UID: user.MultiFactor.EnrolledFactors[1].UID, DisplayName: "Phone Number active", FactorID: "phone", Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: "+19876543210", }, PhoneNumber: "+19876543210", EnrollmentTimestamp: user.MultiFactor.EnrolledFactors[1].EnrollmentTimestamp, }, } want := auth.UserRecord{ EmailVerified: true, UserInfo: &auth.UserInfo{ Email: "testuser@example.com", UID: user.UID, ProviderID: "firebase", }, UserMetadata: &auth.UserMetadata{ CreationTimestamp: user.UserMetadata.CreationTimestamp, }, TokensValidAfterMillis: user.TokensValidAfterMillis, MultiFactor: &auth.MultiFactorSettings{ EnrolledFactors: factors, }, } if !reflect.DeepEqual(*user, want) { t.Errorf("CreateUser() = %#v; want = %#v", *user, want) } } func TestUpdateUser(t *testing.T) { // Creates a new user for testing purposes. The user's uid will be // '$name_$tenRandomChars' and email will be // '$name_$tenRandomChars@example.com'. 
createTestUser := func(name string) *auth.UserRecord { // TODO(rsgowman: This function could usefully be employed throughout // this file. tenRandomChars := generateRandomAlphaNumericString(10) userRecord, err := client.CreateUser(context.Background(), (&auth.UserToCreate{}). UID(name+"_"+tenRandomChars). DisplayName(name). Email(name+"_"+tenRandomChars+"@example.com"), ) if err != nil { t.Fatal(err) } return userRecord } mapToProviderUIDs := func(userInfos [](*auth.UserInfo)) []string { providerUIDs := []string{} for i := range userInfos { providerUIDs = append(providerUIDs, userInfos[i].UID) } return providerUIDs } mapToProviderIDs := func(userInfos [](*auth.UserInfo)) []string { providerIDs := []string{} for i := range userInfos { providerIDs = append(providerIDs, userInfos[i].ProviderID) } return providerIDs } contains := func(list []string, target string) bool { for i := range list { if list[i] == target { return true } } return false } containsAll := func(list []string, targets []string) bool { for i := range targets { if !contains(list, targets[i]) { return false } } return true } containsNone := func(list []string, targets []string) bool { for i := range targets { if contains(list, targets[i]) { return false } } return true } updateUser := createTestUser("UpdateUser") defer deleteUser(updateUser.UID) t.Run("SimpleUpdate", func(t *testing.T) { uid := randomUID() newEmail := randomEmail(uid) newPhone := randomPhoneNumber() want := auth.UserInfo{ UID: updateUser.UID, Email: newEmail, PhoneNumber: newPhone, DisplayName: "Updated Name", ProviderID: "firebase", PhotoURL: "https://example.com/updated.png", } params := (&auth.UserToUpdate{}). Email(newEmail). PhoneNumber(newPhone). DisplayName("Updated Name"). PhotoURL("https://example.com/updated.png"). EmailVerified(true). 
Password("newpassowrd") got, err := client.UpdateUser(context.Background(), updateUser.UID, params) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(*got.UserInfo, want) { t.Errorf("UpdateUser().UserInfo = (%#v, %v); want = (%#v, nil)", *got.UserInfo, err, want) } if !got.EmailVerified { t.Error("UpdateUser().EmailVerified = false; want = true") } }) t.Run("LinkFederatedProvider", func(t *testing.T) { // Link user to federated provider googleFederatedUID := "google_uid_" + generateRandomAlphaNumericString(10) params := (&auth.UserToUpdate{}). ProviderToLink((&auth.UserProvider{ ProviderID: "google.com", UID: googleFederatedUID, })) userRecord, err := client.UpdateUser(context.Background(), updateUser.UID, params) if err != nil { t.Fatal(err) } defer func() { // Unlink user from federated provider params = (&auth.UserToUpdate{}).ProvidersToDelete([]string{"google.com"}) userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, params) if err != nil { t.Fatal(err) } }() // Ensure link operation worked as expected providerUIDs := mapToProviderUIDs(userRecord.ProviderUserInfo) providerIDs := mapToProviderIDs(userRecord.ProviderUserInfo) if !contains(providerUIDs, googleFederatedUID) { t.Errorf("UpdateUser().ProviderUserInfo[*].UID = %v; want include %q", providerUIDs, googleFederatedUID) } if !contains(providerIDs, "google.com") { t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID = %v; want include 'google.com'", providerIDs) } }) t.Run("UnlinkFederatedProvider", func(t *testing.T) { // Link user to federated provider googleFederatedUID := "google_uid_" + generateRandomAlphaNumericString(10) params := (&auth.UserToUpdate{}). 
ProviderToLink((&auth.UserProvider{ ProviderID: "google.com", UID: googleFederatedUID, })) userRecord, err := client.UpdateUser(context.Background(), updateUser.UID, params) if err != nil { t.Fatal(err) } // Unlink user from federated provider params = (&auth.UserToUpdate{}).ProvidersToDelete([]string{"google.com"}) userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, params) if err != nil { t.Fatal(err) } // Ensure unlink operation worked as expected providerUIDs := mapToProviderUIDs(userRecord.ProviderUserInfo) providerIDs := mapToProviderIDs(userRecord.ProviderUserInfo) if contains(providerUIDs, googleFederatedUID) { t.Errorf("UpdateUser().ProviderUserInfo[*].UID = %v; want NOT include %q", providerUIDs, googleFederatedUID) } if contains(providerIDs, "google.com") { t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID = %v; want NOT include 'google.com'", providerIDs) } }) t.Run("UnlinkMultipleProvidersAtOnce", func(t *testing.T) { deletePhoneNumberUser(t, "+15555550001") googleFederatedUID := "google_uid_" + generateRandomAlphaNumericString(10) facebookFederatedUID := "facebook_uid_" + generateRandomAlphaNumericString(10) userRecord, err := client.UpdateUser(context.Background(), updateUser.UID, (&auth.UserToUpdate{}). PhoneNumber("+15555550001"). ProviderToLink((&auth.UserProvider{ ProviderID: "google.com", UID: googleFederatedUID, }))) if err != nil { t.Fatal(err) } userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, (&auth.UserToUpdate{}). 
ProviderToLink((&auth.UserProvider{ ProviderID: "facebook.com", UID: facebookFederatedUID, }))) if err != nil { t.Fatal(err) } providerUIDs := mapToProviderUIDs(userRecord.ProviderUserInfo) providerIDs := mapToProviderIDs(userRecord.ProviderUserInfo) wantAll := []string{googleFederatedUID, facebookFederatedUID, "+15555550001"} if !containsAll(providerUIDs, wantAll) { t.Errorf("UpdateUser().ProviderUserInfo[*].UID want include all %v; got %v", wantAll, providerUIDs) } wantAll = []string{"google.com", "facebook.com", "phone"} if !containsAll(providerIDs, wantAll) { t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID want include all %v; got %v", wantAll, providerIDs) } userRecord, err = client.UpdateUser(context.Background(), updateUser.UID, (&auth.UserToUpdate{}). ProvidersToDelete([]string{"google.com", "facebook.com", "phone"})) if err != nil { t.Fatal(err) } providerUIDs = mapToProviderUIDs(userRecord.ProviderUserInfo) providerIDs = mapToProviderIDs(userRecord.ProviderUserInfo) notWantAll := []string{googleFederatedUID, facebookFederatedUID, "+15555550001"} if !containsNone(providerUIDs, notWantAll) { t.Errorf("UpdateUser().ProviderUserInfo[*].UID want not include all %v; got %v", notWantAll, providerUIDs) } notWantAll = []string{"google.com", "facebook.com", "phone"} if !containsNone(providerIDs, notWantAll) { t.Errorf("UpdateUser().ProviderUserInfo[*].ProviderID want not include all %v; got %v", notWantAll, providerIDs) } }) t.Run("ErrorsGivenEmptyProvidersToDelete", func(t *testing.T) { userRecord := createTestUser("ErrorWithEmptyProvidersToDeleteUser") defer deleteUser(userRecord.UID) gotUserRecord, err := client.UpdateUser(context.Background(), userRecord.UID, (&auth.UserToUpdate{}).ProvidersToDelete([]string{})) if err == nil || gotUserRecord != nil { t.Errorf("UpdateUser() = (%#v, nil); want (nil, error)", gotUserRecord) } }) } func TestUpdateUserMFA(t *testing.T) { // Creates a new user for testing purposes. 
The user's uid will be // '$name_$tenRandomChars' and email will be // '$name_$tenRandomChars@example.com'. createTestUserWithMFA := func(name string) *auth.UserRecord { // TODO(rsgowman: This function could usefully be employed throughout // this file. tenRandomChars := generateRandomAlphaNumericString(10) userRecord, err := client.CreateUser(context.Background(), (&auth.UserToCreate{}). Email(name+"_"+tenRandomChars+"@example.com"). EmailVerified(true). MFASettings(auth.MultiFactorSettings{ EnrolledFactors: []*auth.MultiFactorInfo{ { Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: "+11234567890", }, DisplayName: "Phone Number active", FactorID: "phone", }, { PhoneNumber: "+19876543210", DisplayName: "Phone Number deprecated", FactorID: "phone", }, }, }), ) if err != nil { t.Fatal(err) } return userRecord } // Create a test user with MFA settings for testing user := createTestUserWithMFA("UpdateUserMFA") defer deleteUser(user.UID) // Define the updated MFA factors updatedFactors := []*auth.MultiFactorInfo{ { DisplayName: "Phone Number active updated", FactorID: "phone", Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: "+11234567890", }, }, { DisplayName: "Phone Number deprecated updated", FactorID: "phone", PhoneNumber: "+19876543210", }, } // Update the MFA settings params := (&auth.UserToUpdate{}).MFASettings(auth.MultiFactorSettings{ EnrolledFactors: updatedFactors, }) updatedUser, err := client.UpdateUser(context.Background(), user.UID, params) if err != nil { t.Fatal(err) } want := auth.UserRecord{ EmailVerified: true, UserInfo: &auth.UserInfo{ Email: updatedUser.Email, UID: updatedUser.UID, ProviderID: "firebase", }, UserMetadata: &auth.UserMetadata{ CreationTimestamp: updatedUser.UserMetadata.CreationTimestamp, }, TokensValidAfterMillis: updatedUser.TokensValidAfterMillis, MultiFactor: &auth.MultiFactorSettings{ EnrolledFactors: []*auth.MultiFactorInfo{ { UID: updatedUser.MultiFactor.EnrolledFactors[0].UID, Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: 
"+11234567890", }, PhoneNumber: "+11234567890", DisplayName: "Phone Number active updated", FactorID: "phone", EnrollmentTimestamp: updatedUser.MultiFactor.EnrolledFactors[0].EnrollmentTimestamp, }, { UID: updatedUser.MultiFactor.EnrolledFactors[1].UID, Phone: &auth.PhoneMultiFactorInfo{ PhoneNumber: "+19876543210", }, PhoneNumber: "+19876543210", DisplayName: "Phone Number deprecated updated", FactorID: "phone", EnrollmentTimestamp: updatedUser.MultiFactor.EnrolledFactors[1].EnrollmentTimestamp, }, }, }, } // Compare the updated user with the expected user record if !reflect.DeepEqual(*updatedUser, want) { t.Errorf("UpdateUser() = %#v; want = %#v", *updatedUser, want) } } func TestDisableUser(t *testing.T) { user := newUserWithParams(t) defer deleteUser(user.UID) if user.Disabled { t.Errorf("NewUser.Disabled = true; want = false") } params := (&auth.UserToUpdate{}).Disabled(true) got, err := client.UpdateUser(context.Background(), user.UID, params) if err != nil { t.Fatal(err) } if !got.Disabled { t.Errorf("UpdateUser().Disabled = false; want = true") } params = (&auth.UserToUpdate{}).Disabled(false) got, err = client.UpdateUser(context.Background(), user.UID, params) if err != nil { t.Fatal(err) } if got.Disabled { t.Errorf("UpdateUser().Disabled = true; want = false") } } func TestRemovePhonePhotoName(t *testing.T) { user := newUserWithParams(t) defer deleteUser(user.UID) if user.PhoneNumber == "" { t.Errorf("NewUser.PhoneNumber = empty; want = non-empty") } if len(user.ProviderUserInfo) != 2 { t.Errorf("NewUser.ProviderUserInfo = %d; want = 2", len(user.ProviderUserInfo)) } if user.PhotoURL == "" { t.Errorf("NewUser.PhotoURL = empty; want = non-empty") } if user.DisplayName == "" { t.Errorf("NewUser.DisplayName = empty; want = non-empty") } params := (&auth.UserToUpdate{}).PhoneNumber("").PhotoURL("").DisplayName("") got, err := client.UpdateUser(context.Background(), user.UID, params) if err != nil { t.Fatal(err) } if got.PhoneNumber != "" { 
t.Errorf("UpdateUser().PhoneNumber = %q; want: %q", got.PhoneNumber, "") } if len(got.ProviderUserInfo) != 1 { t.Errorf("UpdateUser().ProviderUserInfo = %d, want = 1", len(got.ProviderUserInfo)) } if got.DisplayName != "" { t.Errorf("UpdateUser().DisplayName = %q; want =%q", got.DisplayName, "") } if got.PhotoURL != "" { t.Errorf("UpdateUser().PhotoURL = %q; want = %q", got.PhotoURL, "") } } func TestSetCustomClaims(t *testing.T) { user := newUserWithParams(t) defer deleteUser(user.UID) if user.CustomClaims != nil { t.Fatalf("NewUser.CustomClaims = %#v; want = nil", user.CustomClaims) } setAndVerifyClaims := func(claims map[string]interface{}) { if err := client.SetCustomUserClaims(context.Background(), user.UID, claims); err != nil { t.Fatalf("SetCustomUserClaims() = %v; want = nil", err) } got, err := client.GetUser(context.Background(), user.UID) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(got.CustomClaims, claims) { t.Errorf("SetCustomUserClaims().CustomClaims = %#v; want = %#v", got.CustomClaims, claims) } } setAndVerifyClaims(map[string]interface{}{ "admin": true, "package": "gold", }) setAndVerifyClaims(map[string]interface{}{ "admin": false, "subscription": "guest", }) setAndVerifyClaims(nil) } func TestDeleteUser(t *testing.T) { user := newUserWithParams(t) if err := client.DeleteUser(context.Background(), user.UID); err != nil { t.Fatalf("DeleteUser() = %v; want = nil", err) } got, err := client.GetUser(context.Background(), user.UID) if err == nil || got != nil || !auth.IsUserNotFound(err) { t.Errorf("GetUser(deleted) = (%#v, %v); want = (nil, error)", got, err) } } func TestDeleteUsers(t *testing.T) { // Deletes users slowly. There's currently a 1qps limitation on this API. // Without this helper, the integration tests occasionally hit that limit // and fail. // // TODO(rsgowman): Remove this function when/if the 1qps limitation is // relaxed. 
slowDeleteUsers := func(ctx context.Context, uids []string) (*auth.DeleteUsersResult, error) { time.Sleep(1 * time.Second) return client.DeleteUsers(ctx, uids) } // Ensures the specified users don't exist. Expected to be called after // deleting the users to ensure the delete method worked. ensureUsersNotFound := func(t *testing.T, uids []string) { identifiers := []auth.UserIdentifier{} for i := range uids { identifiers = append(identifiers, auth.UIDIdentifier{UID: uids[i]}) } getUsersResult, err := client.GetUsers(context.Background(), identifiers) if err != nil { t.Errorf("GetUsers(notfound_ids) error %v; want nil", err) return } if len(getUsersResult.NotFound) != len(uids) { t.Errorf("len(GetUsers(notfound_ids).NotFound) = %d; want %d", len(getUsersResult.NotFound), len(uids)) return } sort.Strings(uids) notFoundUids := []string{} for i := range getUsersResult.NotFound { notFoundUids = append(notFoundUids, getUsersResult.NotFound[i].(auth.UIDIdentifier).UID) } sort.Strings(notFoundUids) for i := range uids { if notFoundUids[i] != uids[i] { t.Errorf("GetUsers(deleted_ids).NotFound[%d] = %s; want %s", i, notFoundUids[i], uids[i]) } } } t.Run("deletes users", func(t *testing.T) { uids := []string{ newUserWithParams(t).UID, newUserWithParams(t).UID, newUserWithParams(t).UID, } result, err := slowDeleteUsers(context.Background(), uids) if err != nil { t.Fatalf("DeleteUsers([valid_ids]) error %v; want nil", err) } if result.SuccessCount != 3 { t.Errorf("DeleteUsers([valid_ids]).SuccessCount = %d; want 3", result.SuccessCount) } if result.FailureCount != 0 { t.Errorf("DeleteUsers([valid_ids]).FailureCount = %d; want 0", result.FailureCount) } if len(result.Errors) != 0 { t.Errorf("len(DeleteUsers([valid_ids]).Errors) = %d; want 0", len(result.Errors)) } ensureUsersNotFound(t, uids) }) t.Run("deletes users that exist even when non-existing users also specified", func(t *testing.T) { uids := []string{newUserWithParams(t).UID, "uid-that-doesnt-exist"} result, err := 
slowDeleteUsers(context.Background(), uids) if err != nil { t.Fatalf("DeleteUsers(uids) error %v; want nil", err) } if result.SuccessCount != 2 { t.Errorf("DeleteUsers(uids).SuccessCount = %d; want 2", result.SuccessCount) } if result.FailureCount != 0 { t.Errorf("DeleteUsers(uids).FailureCount = %d; want 0", result.FailureCount) } if len(result.Errors) != 0 { t.Errorf("len(DeleteUsers(uids).Errors) = %d; want 0", len(result.Errors)) } ensureUsersNotFound(t, uids) }) t.Run("is idempotent", func(t *testing.T) { deleteUserAndEnsureSuccess := func(t *testing.T, uids []string) { result, err := slowDeleteUsers(context.Background(), uids) if err != nil { t.Fatalf("DeleteUsers(uids) error %v; want nil", err) } if result.SuccessCount != 1 { t.Errorf("DeleteUsers(uids).SuccessCount = %d; want 1", result.SuccessCount) } if result.FailureCount != 0 { t.Errorf("DeleteUsers(uids).FailureCount = %d; want 0", result.FailureCount) } if len(result.Errors) != 0 { t.Errorf("len(DeleteUsers(uids).Errors) = %d; want 0", len(result.Errors)) } } uids := []string{newUserWithParams(t).UID} deleteUserAndEnsureSuccess(t, uids) // Delete the user again, ensuring that everything still counts as a success. 
deleteUserAndEnsureSuccess(t, uids) }) } func TestImportUsers(t *testing.T) { uid := randomUID() email := randomEmail(uid) user := (&auth.UserToImport{}).UID(uid).Email(email) result, err := client.ImportUsers(context.Background(), []*auth.UserToImport{user}) if err != nil { t.Fatal(err) } defer deleteUser(uid) if result.SuccessCount != 1 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 1, FailureCount: 0}", result) } savedUser, err := client.GetUser(context.Background(), uid) if err != nil { t.Fatal(err) } if savedUser.Email != email { t.Errorf("GetUser(imported) = %q; want = %q", savedUser.Email, email) } } func TestImportUsersWithPassword(t *testing.T) { scrypt, passwordHash, err := newScryptHash() if err != nil { t.Fatalf("newScryptHash() = %v", err) } uid := randomUID() email := randomEmail(uid) user := (&auth.UserToImport{}). UID(uid). Email(email). PasswordHash(passwordHash). PasswordSalt([]byte("NaCl")) result, err := client.ImportUsers(context.Background(), []*auth.UserToImport{user}, auth.WithHash(scrypt)) if err != nil { t.Fatal(err) } defer deleteUser(uid) if result.SuccessCount != 1 || result.FailureCount != 0 { t.Errorf("ImportUsers() = %#v; want = {SuccessCount: 1, FailureCount: 0}", result) } savedUser, err := client.GetUser(context.Background(), uid) if err != nil { t.Fatal(err) } if savedUser.Email != email { t.Errorf("GetUser(imported) = %q; want = %q", savedUser.Email, email) } idToken, err := signInWithPassword(email, "password") if err != nil { t.Fatal(err) } if idToken == "" { t.Errorf("ID Token = empty; want = non-empty") } } func newScryptHash() (*hash.Scrypt, []byte, error) { const ( rawScryptKey = "jxspr8Ki0RYycVU8zykbdLGjFQ3McFUH0uiiTvC8pVMXAn210wjLNmdZJzxUECKbm0QsEmYUSDzZvpjeJ9WmXA==" rawPasswordHash = "V358E8LdWJXAO7muq0CufVpEOXaj8aFiC7T/rcaGieN04q/ZPJ08WhJEHGjj9lz/2TT+/86N5VjVoc5DdBhBiw==" rawSeparator = "Bw==" ) scryptKey, err := base64.StdEncoding.DecodeString(rawScryptKey) if err != nil { return 
nil, nil, err } saltSeparator, err := base64.StdEncoding.DecodeString(rawSeparator) if err != nil { return nil, nil, err } passwordHash, err := base64.StdEncoding.DecodeString(rawPasswordHash) if err != nil { return nil, nil, err } scrypt := hash.Scrypt{ Key: scryptKey, SaltSeparator: saltSeparator, Rounds: 8, MemoryCost: 14, } return &scrypt, passwordHash, nil } func TestSessionCookie(t *testing.T) { uid := "cookieuser" customToken, err := client.CustomToken(context.Background(), uid) if err != nil { t.Fatal(err) } idToken, err := signInWithCustomToken(customToken) if err != nil { t.Fatal(err) } defer deleteUser(uid) cookie, err := client.SessionCookie(context.Background(), idToken, 10*time.Minute) if err != nil { t.Fatal(err) } if cookie == "" { t.Errorf("SessionCookie() = %q; want = non-empty", cookie) } vt, err := client.VerifySessionCookieAndCheckRevoked(context.Background(), cookie) if err != nil { t.Fatal(err) } if vt.UID != uid { t.Errorf("VerifySessionCookieAndCheckRevoked() UID = %q; want = %q", vt.UID, uid) } // The backend stores the validSince property in seconds since the epoch. // The issuedAt property of the token is also in seconds. If a token was // issued, and then in the same second tokens were revoked, the token will // have the same timestamp as the tokensValidAfterMillis, and will therefore // not be considered revoked. Hence we wait one second before revoking. time.Sleep(time.Second) if err = client.RevokeRefreshTokens(context.Background(), uid); err != nil { t.Fatal(err) } vt, err = client.VerifySessionCookieAndCheckRevoked(context.Background(), cookie) if vt != nil || err == nil || !auth.IsSessionCookieRevoked(err) { t.Errorf("tok, err := VerifySessionCookieAndCheckRevoked() = (%v, %v); want = (nil, session-cookie-revoked)", vt, err) } // Does not return error for revoked token. 
if _, err = client.VerifySessionCookie(context.Background(), cookie); err != nil { t.Errorf("VerifySessionCookie() = %v; want = nil", err) } } func TestEmailVerificationLink(t *testing.T) { user := newUserWithParams(t) defer deleteUser(user.UID) link, err := client.EmailVerificationLinkWithSettings(context.Background(), user.Email, &auth.ActionCodeSettings{ URL: continueURL, HandleCodeInApp: false, }) if err != nil { t.Fatal(err) } parsed, err := url.ParseRequestURI(link) if err != nil { t.Fatal(err) } query := parsed.Query() if got := query.Get(continueURLKey); got != continueURL { t.Errorf("EmailVerificationLinkWithSettings() %s = %q; want = %q", continueURLKey, got, continueURL) } const verifyEmail = "verifyEmail" if got := query.Get(modeKey); got != verifyEmail { t.Errorf("EmailVerificationLinkWithSettings() %s = %q; want = %q", modeKey, got, verifyEmail) } } func TestPasswordResetLink(t *testing.T) { user := newUserWithParams(t) defer deleteUser(user.UID) link, err := client.PasswordResetLinkWithSettings(context.Background(), user.Email, &auth.ActionCodeSettings{ URL: continueURL, HandleCodeInApp: false, }) if err != nil { t.Fatal(err) } parsed, err := url.ParseRequestURI(link) if err != nil { t.Fatal(err) } query := parsed.Query() if got := query.Get(continueURLKey); got != continueURL { t.Errorf("PasswordResetLinkWithSettings() %s = %q; want = %q", continueURLKey, got, continueURL) } oobCode := query.Get(oobCodeKey) if err := resetPassword(user.Email, "password", "newPassword", oobCode); err != nil { t.Fatalf("PasswordResetLinkWithSettings() reset = %v; want = nil", err) } // Password reset also verifies the user's email user, err = client.GetUser(context.Background(), user.UID) if err != nil { t.Fatalf("GetUser() = %v; want = nil", err) } if !user.EmailVerified { t.Error("PasswordResetLinkWithSettings() EmailVerified = false; want = true") } } func TestEmailSignInLink(t *testing.T) { user := newUserWithParams(t) defer deleteUser(user.UID) link, err := 
client.EmailSignInLink(context.Background(), user.Email, &auth.ActionCodeSettings{ URL: continueURL, HandleCodeInApp: false, }) if err != nil { t.Fatal(err) } parsed, err := url.ParseRequestURI(link) if err != nil { t.Fatal(err) } query := parsed.Query() if got := query.Get(continueURLKey); got != continueURL { t.Errorf("EmailSignInLink() %s = %q; want = %q", continueURLKey, got, continueURL) } oobCode := query.Get(oobCodeKey) idToken, err := signInWithEmailLink(user.Email, oobCode) if err != nil { t.Fatalf("EmailSignInLink() signIn = %v; want = nil", err) } if idToken == "" { t.Errorf("ID Token = empty; want = non-empty") } // Signing in with email link also verifies the user's email user, err = client.GetUser(context.Background(), user.UID) if err != nil { t.Fatalf("GetUser() = %v; want = nil", err) } if !user.EmailVerified { t.Error("EmailSignInLink() EmailVerified = false; want = true") } } func resetPassword(email, oldPassword, newPassword, oobCode string) error { req := map[string]interface{}{ "email": email, "oldPassword": oldPassword, "newPassword": newPassword, "oobCode": oobCode, } reqBytes, err := json.Marshal(req) if err != nil { return err } _, err = postRequest(fmt.Sprintf(resetPasswordURL, apiKey), reqBytes) return err } func signInWithEmailLink(email, oobCode string) (string, error) { req := map[string]interface{}{ "email": email, "oobCode": oobCode, } reqBytes, err := json.Marshal(req) if err != nil { return "", err } b, err := postRequest(fmt.Sprintf(emailLinkSignInURL, apiKey), reqBytes) if err != nil { return "", err } var parsed struct { IDToken string `json:"idToken"` } if err := json.Unmarshal(b, &parsed); err != nil { return "", err } return parsed.IDToken, nil } var seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) func randomUID() string { return generateRandomAlphaNumericString(32) } func randomPhoneNumber() string { return "+1" + generateRandomNumericString(10) } func randomEmail(uid string) string { return 
strings.ToLower(fmt.Sprintf("%s@example.%s.com", uid[:12], uid[12:])) } func generateRandomNumericString(length int) string { digits := []rune("0123456789") return generateRandomString(length, digits) } func generateRandomAlphaNumericString(length int) string { letters := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") return generateRandomString(length, letters) } func generateRandomString(length int, runes []rune) string { b := make([]rune, length) for i := range b { b[i] = runes[seededRand.Intn(len(runes))] } return string(b) } func newUserWithParams(t *testing.T) *auth.UserRecord { uid := randomUID() email := randomEmail(uid) phone := randomPhoneNumber() params := (&auth.UserToCreate{}). UID(uid). Email(email). PhoneNumber(phone). DisplayName("Random User"). PhotoURL("https://example.com/photo.png"). Password("password") user, err := client.CreateUser(context.Background(), params) if err != nil { t.Fatal(err) } return user } // Helper to import a user and return its UserRecord. Upon error, exits via // t.Fatalf. `uid` must match the UID set on the `userToImport` parameter. func importUser(t *testing.T, uid string, userToImport *auth.UserToImport) *auth.UserRecord { userImportResult, err := client.ImportUsers( context.Background(), [](*auth.UserToImport){userToImport}) if err != nil { t.Fatalf("Unable to import user %v (uid %v): %v", *userToImport, uid, err) } if userImportResult.FailureCount > 0 { t.Fatalf("Unable to import user %v (uid %v): %v", *userToImport, uid, userImportResult.Errors[0].Reason) } if userImportResult.SuccessCount != 1 { t.Fatalf("Import didn't fail, but it didn't succeed either?") } userRecord, err := client.GetUser(context.Background(), uid) if err != nil { t.Fatalf("GetUser(%s) for imported user failed: %v", uid, err) } return userRecord } // Helper function that deletes the user with the specified phone number if it // exists. 
// TODO(rsgowman): This function was ported from node.js port; a number of tests // there use this, but haven't been ported to go yet. Do so. func deletePhoneNumberUser(t *testing.T, phoneNumber string) { userRecord, err := client.GetUserByPhoneNumber(context.Background(), phoneNumber) if err != nil { if auth.IsUserNotFound(err) { // User already doesn't exist. return } t.Fatal(err) } if err = client.DeleteUser(context.Background(), userRecord.UID); err != nil { t.Fatal(err) } } golang-google-firebase-go-4.18.0/integration/db/000077500000000000000000000000001505612111400214125ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/db/db_test.go000066400000000000000000000457501505612111400234000ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package db contains integration tests for the firebase.google.com/go/db package. 
package db import ( "bytes" "context" "encoding/json" "flag" "fmt" "io/ioutil" "log" "net/http" "os" "reflect" "testing" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/db" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/integration/internal" ) var client *db.Client var aoClient *db.Client var guestClient *db.Client var ref *db.Ref var users *db.Ref var dinos *db.Ref var testData map[string]interface{} var parsedTestData map[string]Dinosaur const permDenied = "http error status: 401; reason: Permission denied" func TestMain(m *testing.M) { flag.Parse() if testing.Short() { log.Println("skipping database integration tests in short mode.") os.Exit(0) } pid, err := internal.ProjectID() if err != nil { log.Fatalln(err) } client, err = initClient(pid) if err != nil { log.Fatalln(err) } aoClient, err = initOverrideClient(pid) if err != nil { log.Fatalln(err) } guestClient, err = initGuestClient(pid) if err != nil { log.Fatalln(err) } ref = client.NewRef("_adminsdk/go/dinodb") dinos = ref.Child("dinosaurs") users = ref.Parent().Child("users") initRules() initData() os.Exit(m.Run()) } func initClient(pid string) (*db.Client, error) { ctx := context.Background() url, err := getDatabaseURL() if err != nil { return nil, err } app, err := internal.NewTestApp(ctx, &firebase.Config{ DatabaseURL: url, }) if err != nil { return nil, err } return app.Database(ctx) } func initOverrideClient(pid string) (*db.Client, error) { ctx := context.Background() ao := map[string]interface{}{"uid": "user1"} url, err := getDatabaseURL() if err != nil { return nil, err } app, err := internal.NewTestApp(ctx, &firebase.Config{ DatabaseURL: url, AuthOverride: &ao, }) if err != nil { return nil, err } return app.Database(ctx) } func initGuestClient(pid string) (*db.Client, error) { ctx := context.Background() var nullMap map[string]interface{} url, err := getDatabaseURL() if err != nil { return nil, err } app, err := internal.NewTestApp(ctx, &firebase.Config{ 
DatabaseURL: url, AuthOverride: &nullMap, }) if err != nil { return nil, err } return app.Database(ctx) } func initRules() { b, err := ioutil.ReadFile(internal.Resource("dinosaurs_index.json")) if err != nil { log.Fatalln(err) } url, err := getDatabaseRulesURL() if err != nil { log.Fatalln(err) } req, err := http.NewRequest("PUT", url, bytes.NewBuffer(b)) if err != nil { log.Fatalln(err) } req.Header.Set("Content-Type", "application/json") hc, err := internal.NewHTTPClient(context.Background()) if err != nil { log.Fatalln(err) } resp, err := hc.Do(req) if err != nil { log.Fatalln(err) } defer resp.Body.Close() b, err = ioutil.ReadAll(resp.Body) if err != nil { log.Fatalln(err) } else if resp.StatusCode != http.StatusOK { log.Fatalln("failed to update rules:", string(b)) } } func initData() { b, err := ioutil.ReadFile(internal.Resource("dinosaurs.json")) if err != nil { log.Fatalln(err) } if err = json.Unmarshal(b, &testData); err != nil { log.Fatalln(err) } b, err = json.Marshal(testData["dinosaurs"]) if err != nil { log.Fatalln(err) } if err = json.Unmarshal(b, &parsedTestData); err != nil { log.Fatalln(err) } if err = ref.Set(context.Background(), testData); err != nil { log.Fatalln(err) } } func TestRef(t *testing.T) { if ref.Key != "dinodb" { t.Errorf("Key = %q; want = %q", ref.Key, "dinodb") } if ref.Path != "/_adminsdk/go/dinodb" { t.Errorf("Path = %q; want = %q", ref.Path, "/_adminsdk/go/dinodb") } } func TestChild(t *testing.T) { c := ref.Child("dinosaurs") if c.Key != "dinosaurs" { t.Errorf("Key = %q; want = %q", c.Key, "dinosaurs") } if c.Path != "/_adminsdk/go/dinodb/dinosaurs" { t.Errorf("Path = %q; want = %q", c.Path, "/_adminsdk/go/dinodb/dinosaurs") } } func TestParent(t *testing.T) { p := ref.Parent() if p.Key != "go" { t.Errorf("Key = %q; want = %q", p.Key, "go") } if p.Path != "/_adminsdk/go" { t.Errorf("Path = %q; want = %q", p.Path, "/_adminsdk/go") } } func TestGet(t *testing.T) { var m map[string]interface{} if err := 
ref.Get(context.Background(), &m); err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData, m) { t.Errorf("Get() = %v; want = %v", m, testData) } } func TestGetWithETag(t *testing.T) { var m map[string]interface{} etag, err := ref.GetWithETag(context.Background(), &m) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData, m) { t.Errorf("GetWithETag() = %v; want = %v", m, testData) } if etag == "" { t.Errorf("GetWithETag() = \"\"; want non-empty") } } func TestGetShallow(t *testing.T) { var m map[string]interface{} if err := ref.GetShallow(context.Background(), &m); err != nil { t.Fatal(err) } want := map[string]interface{}{} for k := range testData { want[k] = true } if !reflect.DeepEqual(want, m) { t.Errorf("GetShallow() = %v; want = %v", m, want) } } func TestGetIfChanged(t *testing.T) { var m map[string]interface{} ok, etag, err := ref.GetIfChanged(context.Background(), "wrong-etag", &m) if err != nil { t.Fatal(err) } if !ok || etag == "" { t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag, true, "non-empty") } if !reflect.DeepEqual(testData, m) { t.Errorf("GetWithETag() = %v; want = %v", m, testData) } var m2 map[string]interface{} ok, etag2, err := ref.GetIfChanged(context.Background(), etag, &m2) if err != nil { t.Fatal(err) } if ok || etag != etag2 { t.Errorf("GetIfChanged() = (%v, %q); want = (%v, %q)", ok, etag2, false, etag) } if len(m2) != 0 { t.Errorf("GetWithETag() = %v; want empty", m) } } func TestGetChildValue(t *testing.T) { c := ref.Child("dinosaurs") var m map[string]interface{} if err := c.Get(context.Background(), &m); err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData["dinosaurs"], m) { t.Errorf("Get() = %v; want = %v", m, testData["dinosaurs"]) } } func TestGetGrandChildValue(t *testing.T) { c := ref.Child("dinosaurs/lambeosaurus") var got Dinosaur if err := c.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := parsedTestData["lambeosaurus"] if !reflect.DeepEqual(want, got) { t.Errorf("Get() = 
%v; want = %v", got, want) } } func TestGetNonExistingChild(t *testing.T) { c := ref.Child("non_existing") var i interface{} if err := c.Get(context.Background(), &i); err != nil { t.Fatal(err) } if i != nil { t.Errorf("Get() = %v; want nil", i) } } func TestPush(t *testing.T) { u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } if u.Path != "/_adminsdk/go/users/"+u.Key { t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) } var i interface{} if err := u.Get(context.Background(), &i); err != nil { t.Fatal(err) } if i != "" { t.Errorf("Get() = %v; want empty string", i) } } func TestPushWithValue(t *testing.T) { want := User{"Luis Alvarez", 1911} u, err := users.Push(context.Background(), &want) if err != nil { t.Fatal(err) } if u.Path != "/_adminsdk/go/users/"+u.Key { t.Errorf("Push() = %q; want = %q", u.Path, "/_adminsdk/go/users/"+u.Key) } var got User if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if want != got { t.Errorf("Get() = %v; want = %v", got, want) } } func TestSetPrimitiveValue(t *testing.T) { u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } if err := u.Set(context.Background(), "value"); err != nil { t.Fatal(err) } var got string if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "value" { t.Errorf("Get() = %q; want = %q", got, "value") } } func TestSetComplexValue(t *testing.T) { u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } want := User{"Mary Anning", 1799} if err := u.Set(context.Background(), &want); err != nil { t.Fatal(err) } var got User if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != want { t.Errorf("Get() = %v; want = %v", got, want) } } func TestUpdateChildren(t *testing.T) { u, err := users.Push(context.Background(), nil) if err != nil { t.Fatal(err) } want := map[string]interface{}{ "name": "Robert Bakker", "since": float64(1945), } if 
err := u.Update(context.Background(), want); err != nil { t.Fatal(err) } var got map[string]interface{} if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } } func TestUpdateChildrenWithExistingValue(t *testing.T) { u, err := users.Push(context.Background(), map[string]interface{}{ "name": "Edwin Colbert", "since": float64(1900), }) if err != nil { t.Fatal(err) } update := map[string]interface{}{"since": float64(1905)} if err := u.Update(context.Background(), update); err != nil { t.Fatal(err) } var got map[string]interface{} if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := map[string]interface{}{ "name": "Edwin Colbert", "since": float64(1905), } if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } } func TestUpdateNestedChildren(t *testing.T) { edward, err := users.Push(context.Background(), map[string]interface{}{ "name": "Edward Cope", "since": float64(1800), }) if err != nil { t.Fatal(err) } jack, err := users.Push(context.Background(), map[string]interface{}{ "name": "Jack Horner", "since": float64(1940), }) if err != nil { t.Fatal(err) } delta := map[string]interface{}{ fmt.Sprintf("%s/since", edward.Key): 1840, fmt.Sprintf("%s/since", jack.Key): 1946, } if err := users.Update(context.Background(), delta); err != nil { t.Fatal(err) } var got map[string]interface{} if err := edward.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := map[string]interface{}{"name": "Edward Cope", "since": float64(1840)} if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } if err := jack.Get(context.Background(), &got); err != nil { t.Fatal(err) } want = map[string]interface{}{"name": "Jack Horner", "since": float64(1946)} if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } } func TestSetIfChanged(t *testing.T) { edward, err := 
users.Push(context.Background(), &User{"Edward Cope", 1800}) if err != nil { t.Fatal(err) } update := User{"Jack Horner", 1940} ok, err := edward.SetIfUnchanged(context.Background(), "invalid-etag", &update) if err != nil { t.Fatal(err) } if ok { t.Errorf("SetIfUnchanged() = %v; want = %v", ok, false) } var u User etag, err := edward.GetWithETag(context.Background(), &u) if err != nil { t.Fatal(err) } ok, err = edward.SetIfUnchanged(context.Background(), etag, &update) if err != nil { t.Fatal(err) } if !ok { t.Errorf("SetIfUnchanged() = %v; want = %v", ok, true) } if err := edward.Get(context.Background(), &u); err != nil { t.Fatal(err) } if !reflect.DeepEqual(update, u) { t.Errorf("Get() = %v; want = %v", u, update) } } func TestTransaction(t *testing.T) { u, err := users.Push(context.Background(), &User{Name: "Richard"}) if err != nil { t.Fatal(err) } fn := func(t db.TransactionNode) (interface{}, error) { var user User if err := t.Unmarshal(&user); err != nil { return nil, err } user.Name = "Richard Owen" user.Since = 1804 return &user, nil } if err := u.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } var got User if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } want := User{"Richard Owen", 1804} if !reflect.DeepEqual(want, got) { t.Errorf("Get() = %v; want = %v", got, want) } } func TestTransactionScalar(t *testing.T) { cnt := users.Child("count") if err := cnt.Set(context.Background(), 42); err != nil { t.Fatal(err) } fn := func(t db.TransactionNode) (interface{}, error) { var snap float64 if err := t.Unmarshal(&snap); err != nil { return nil, err } return snap + 1, nil } if err := cnt.Transaction(context.Background(), fn); err != nil { t.Fatal(err) } var got float64 if err := cnt.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != 43.0 { t.Errorf("Get() = %v; want = %v", got, 43.0) } } func TestDelete(t *testing.T) { u, err := users.Push(context.Background(), "foo") if err != nil { t.Fatal(err) } 
var got string if err := u.Get(context.Background(), &got); err != nil { t.Fatal(err) } if got != "foo" { t.Errorf("Get() = %q; want = %q", got, "foo") } if err := u.Delete(context.Background()); err != nil { t.Fatal(err) } var got2 string if err := u.Get(context.Background(), &got2); err != nil { t.Fatal(err) } if got2 != "" { t.Errorf("Get() = %q; want = %q", got2, "") } } func TestNoAccess(t *testing.T) { r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/admin")) var got string if err := r.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else { if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } if err := r.Set(context.Background(), "update"); err == nil { t.Errorf("Set() = nil; want = error") } else { if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } } func TestReadAccess(t *testing.T) { r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user2")) var got string if err := r.Get(context.Background(), &got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") } err := r.Set(context.Background(), "update") if err == nil { t.Fatalf("Set() = nil; want = error") } if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } func TestReadWriteAccess(t *testing.T) { r := aoClient.NewRef(protectedRef(t, "_adminsdk/go/protected/user1")) var got string if err := r.Get(context.Background(), &got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") } if err := r.Set(context.Background(), "update"); err != nil 
{ t.Errorf("Set() = %v; want = nil", err) } } func TestQueryAccess(t *testing.T) { r := aoClient.NewRef("_adminsdk/go/protected") got := make(map[string]interface{}) err := r.OrderByKey().LimitToFirst(2).Get(context.Background(), &got) if err == nil { t.Fatalf("OrderByQuery() = nil; want = error") } if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } func TestGuestAccess(t *testing.T) { r := guestClient.NewRef(protectedRef(t, "_adminsdk/go/public")) var got string if err := r.Get(context.Background(), &got); err != nil || got != "test" { t.Errorf("Get() = (%q, %v); want = (%q, nil)", got, err, "test") } if err := r.Set(context.Background(), "update"); err == nil { t.Errorf("Set() = nil; want = error") } else { if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } got = "" r = guestClient.NewRef("_adminsdk/go") if err := r.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else { if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } c := r.Child("protected/user2") if err := c.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else { if err.Error() != permDenied { t.Errorf("Error = %q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } c = r.Child("admin") if err := c.Get(context.Background(), &got); err == nil || got != "" { t.Errorf("Get() = (%q, %v); want = (empty, error)", got, err) } else { if err.Error() != permDenied { t.Errorf("Error = 
%q; want = %q", err.Error(), permDenied) } if !errorutils.IsUnauthenticated(err) { t.Errorf("IsUnauthenticated() = false; want = true") } } } func TestWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var m map[string]interface{} if err := ref.Get(ctx, &m); err != nil { t.Fatal(err) } if !reflect.DeepEqual(testData, m) { t.Errorf("Get() = %v; want = %v", m, testData) } cancel() m = nil if err := ref.Get(ctx, &m); len(m) != 0 || err == nil { t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) } } func protectedRef(t *testing.T, p string) string { r := client.NewRef(p) if err := r.Set(context.Background(), "test"); err != nil { t.Fatal(err) } return p } type Dinosaur struct { Appeared float64 `json:"appeared"` Height float64 `json:"height"` Length float64 `json:"length"` Order string `json:"order"` Vanished float64 `json:"vanished"` Weight int `json:"weight"` Ratings Ratings `json:"ratings"` } type Ratings struct { Pos int `json:"pos"` } type User struct { Name string `json:"name"` Since int `json:"since"` } func getDatabaseRulesURL() (string, error) { emulatorHost := os.Getenv("FIREBASE_DATABASE_EMULATOR_HOST") if emulatorHost != "" { return fmt.Sprintf("http://%s/.settings/rules.json?ns=%s", emulatorHost, os.Getenv("FIREBASE_DATABASE_EMULATOR_NAMESPACE")), nil } prodURL, err := getProductionURL() if err != nil { return "", err } return fmt.Sprintf("%s/.settings/rules.json", prodURL), nil } func getDatabaseURL() (string, error) { emulatorHost := os.Getenv("FIREBASE_DATABASE_EMULATOR_HOST") if emulatorHost != "" { return fmt.Sprintf("%s?ns=%s", emulatorHost, os.Getenv("FIREBASE_DATABASE_EMULATOR_NAMESPACE")), nil } return getProductionURL() } func getProductionURL() (string, error) { pid, err := internal.ProjectID() if err != nil { return "", err } return fmt.Sprintf("https://%s.firebaseio.com", pid), nil } 
golang-google-firebase-go-4.18.0/integration/db/query_test.go000066400000000000000000000155041505612111400241520ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package db import ( "context" "reflect" "testing" "firebase.google.com/go/v4/db" ) var heightSorted = []string{ "linhenykus", "pterodactyl", "lambeosaurus", "triceratops", "stegosaurus", "bruhathkayosaurus", } func TestLimitToFirst(t *testing.T) { for _, tc := range []int{2, 10} { results, err := dinos.OrderByChild("height").LimitToFirst(tc).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } wl := min(tc, len(heightSorted)) want := heightSorted[:wl] if len(results) != wl { t.Errorf("LimitToFirst() = %d; want = %d", len(results), wl) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } } func TestLimitToLast(t *testing.T) { for _, tc := range []int{2, 10} { results, err := dinos.OrderByChild("height").LimitToLast(tc).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } wl := min(tc, len(heightSorted)) want := heightSorted[len(heightSorted)-wl:] if len(results) != wl { t.Errorf("LimitToLast() = %d; want = %d", len(results), wl) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } } func TestStartAt(t *testing.T) { results, err := 
dinos.OrderByChild("height").StartAt(3.5).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-2:] if len(results) != len(want) { t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } func TestEndAt(t *testing.T) { results, err := dinos.OrderByChild("height").EndAt(3.5).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := heightSorted[:4] if len(results) != len(want) { t.Errorf("StartAt() = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } func TestStartAndEndAt(t *testing.T) { results, err := dinos.OrderByChild("height").StartAt(2.5).EndAt(5).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] if len(results) != len(want) { t.Errorf("StartAt(), EndAt() = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } func TestEqualTo(t *testing.T) { results, err := dinos.OrderByChild("height").EqualTo(0.6).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := heightSorted[:2] if len(results) != len(want) { t.Errorf("EqualTo() = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } func TestOrderByNestedChild(t *testing.T) { results, err := dinos.OrderByChild("ratings/pos").StartAt(4).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := []string{"pterodactyl", "stegosaurus", "triceratops"} if len(results) != len(want) { 
t.Errorf("OrderByChild(ratings/pos) = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } func TestOrderByKey(t *testing.T) { results, err := dinos.OrderByKey().LimitToFirst(2).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := []string{"bruhathkayosaurus", "lambeosaurus"} if len(results) != len(want) { t.Errorf("OrderByKey() = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } compareValues(t, results) } func TestOrderByValue(t *testing.T) { scores := ref.Child("scores") results, err := scores.OrderByValue().LimitToLast(2).GetOrdered(context.Background()) if err != nil { t.Fatal(err) } want := []string{"linhenykus", "pterodactyl"} if len(results) != len(want) { t.Errorf("OrderByValue() = %d; want = %d", len(results), len(want)) } got := getNames(results) if !reflect.DeepEqual(got, want) { t.Errorf("LimitToLast() = %v; want = %v", got, want) } wantScores := []int{80, 93} for i, r := range results { var val int if err := r.Unmarshal(&val); err != nil { t.Fatalf("queryNode.Unmarshal() = %v", err) } if val != wantScores[i] { t.Errorf("queryNode.Unmarshal() = %d; want = %d", val, wantScores[i]) } } } func TestQueryWithContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) q := dinos.OrderByKey().LimitToFirst(2) var m map[string]Dinosaur if err := q.Get(ctx, &m); err != nil { t.Fatal(err) } want := []string{"bruhathkayosaurus", "lambeosaurus"} if len(m) != len(want) { t.Errorf("OrderByKey() = %d; want = %d", len(m), len(want)) } cancel() m = nil if err := q.Get(ctx, &m); len(m) != 0 || err == nil { t.Errorf("Get() = (%v, %v); want = (empty, error)", m, err) } } func TestUnorderedQuery(t *testing.T) { var m map[string]Dinosaur if err := dinos.OrderByChild("height"). StartAt(2.5). EndAt(5). 
Get(context.Background(), &m); err != nil { t.Fatal(err) } want := heightSorted[len(heightSorted)-3 : len(heightSorted)-1] if len(m) != len(want) { t.Errorf("Get() = %d; want = %d", len(m), len(want)) } for i, w := range want { if _, ok := m[w]; !ok { t.Errorf("[%d] result[%q] not present", i, w) } } } func min(i, j int) int { if i < j { return i } return j } func getNames(results []db.QueryNode) []string { s := make([]string, len(results)) for i, v := range results { s[i] = v.Key() } return s } func compareValues(t *testing.T, results []db.QueryNode) { for _, r := range results { var d Dinosaur if err := r.Unmarshal(&d); err != nil { t.Fatalf("queryNode.Unmarshal(%q) = %v", r.Key(), err) } if !reflect.DeepEqual(d, parsedTestData[r.Key()]) { t.Errorf("queryNode.Unmarshal(%q) = %v; want = %v", r.Key(), d, parsedTestData[r.Key()]) } } } golang-google-firebase-go-4.18.0/integration/firestore/000077500000000000000000000000001505612111400230275ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/firestore/firestore_test.go000066400000000000000000000030071505612111400264170ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package firestore import ( "context" "log" "reflect" "testing" "firebase.google.com/go/v4/integration/internal" ) func TestFirestore(t *testing.T) { if testing.Short() { log.Println("skipping Firestore integration tests in short mode.") return } ctx := context.Background() app, err := internal.NewTestApp(ctx, nil) if err != nil { t.Fatal(err) } client, err := app.Firestore(ctx) if err != nil { t.Fatal(err) } doc := client.Collection("cities").Doc("Mountain View") data := map[string]interface{}{ "name": "Mountain View", "country": "USA", "population": int64(77846), "capital": false, } if _, err := doc.Set(ctx, data); err != nil { t.Fatal(err) } snap, err := doc.Get(ctx) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(snap.Data(), data) { t.Errorf("Get() = %v; want %v", snap.Data(), data) } if _, err := doc.Delete(ctx); err != nil { t.Fatal(err) } } golang-google-firebase-go-4.18.0/integration/iid/000077500000000000000000000000001505612111400215725ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/iid/iid_test.go000066400000000000000000000033561505612111400237340ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package iid contains integration tests for the firebase.google.com/go/iid package. 
package iid import ( "context" "flag" "log" "os" "testing" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/iid" "firebase.google.com/go/v4/integration/internal" ) var client *iid.Client func TestMain(m *testing.M) { flag.Parse() if testing.Short() { log.Println("skipping instance ID integration tests in short mode.") os.Exit(0) } ctx := context.Background() app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } client, err = app.InstanceID(ctx) if err != nil { log.Fatalln(err) } os.Exit(m.Run()) } func TestNonExisting(t *testing.T) { // legal instance IDs are /[cdef][A-Za-z0-9_-]{9}[AEIMQUYcgkosw048]/ // "fictive-ID0" is match for that. err := client.DeleteInstanceID(context.Background(), "fictive-ID0") if err == nil { t.Errorf("DeleteInstanceID(non-existing) = nil; want error") } want := `instance id "fictive-ID0": failed to find the instance id` if !errorutils.IsNotFound(err) || err.Error() != want { t.Errorf("DeleteInstanceID(non-existing) = %v; want = %v", err, want) } } golang-google-firebase-go-4.18.0/integration/internal/000077500000000000000000000000001505612111400226415ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/internal/internal.go000066400000000000000000000052301505612111400250040ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package internal contains utilities for running integration tests. 
package internal import ( "context" "encoding/json" "io/ioutil" "net/http" "path/filepath" "strings" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" "google.golang.org/api/transport" ) const certPath = "integration_cert.json" const apiKeyPath = "integration_apikey.txt" // Resource returns the absolute path to the specified test resource file. func Resource(name string) string { p := []string{"..", "..", "testdata", name} return filepath.Join(p...) } // NewTestApp creates a new App instance for integration tests. // // NewTestApp looks for a service account JSON file named integration_cert.json // in the testdata directory. This file is used to initialize the newly created // App instance. func NewTestApp(ctx context.Context, conf *firebase.Config) (*firebase.App, error) { return firebase.NewApp(ctx, conf, option.WithCredentialsFile(Resource(certPath))) } // APIKey fetches a Firebase API key for integration tests. // // APIKey reads the API key string from a file named integration_apikey.txt // in the testdata directory. func APIKey() (string, error) { b, err := ioutil.ReadFile(Resource(apiKeyPath)) if err != nil { return "", err } return strings.TrimSpace(string(b)), nil } // ProjectID fetches a Google Cloud project ID for integration tests. func ProjectID() (string, error) { b, err := ioutil.ReadFile(Resource(certPath)) if err != nil { return "", err } var serviceAccount struct { ProjectID string `json:"project_id"` } if err := json.Unmarshal(b, &serviceAccount); err != nil { return "", err } return serviceAccount.ProjectID, nil } // NewHTTPClient creates an HTTP client for making authorized requests during tests. func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*http.Client, error) { opts = append( opts, option.WithCredentialsFile(Resource(certPath)), option.WithScopes(internal.FirebaseScopes...), ) hc, _, err := transport.NewHTTPClient(ctx, opts...) 
return hc, err } golang-google-firebase-go-4.18.0/integration/messaging/000077500000000000000000000000001505612111400230025ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/messaging/messaging_test.go000066400000000000000000000241171505612111400263520ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messaging import ( "context" "errors" "flag" "fmt" "log" "os" "regexp" "testing" "firebase.google.com/go/v4/integration/internal" "firebase.google.com/go/v4/messaging" ) // The registration token has the proper format, but is not valid (i.e. expired). The intention of // these integration tests is to verify that the endpoints return the proper payload, but it is // hard to ensure this token remains valid. The tests below should still pass regardless. 
const testRegistrationToken = "fGw0qy4TGgk:APA91bGtWGjuhp4WRhHXgbabIYp1jxEKI08ofj_v1bKhWAGJQ4e3a" + "rRCWzeTfHaLz83mBnDh0aPWB1AykXAVUUGl2h1wT4XI6XazWpvY7RBUSYfoxtqSWGIm2nvWh2BOP1YG501SsRoE" var messageIDPattern = regexp.MustCompile("^projects/.*/messages/.*$") var client *messaging.Client // Enable API before testing // https://console.developers.google.com/apis/library/fcm.googleapis.com func TestMain(m *testing.M) { flag.Parse() if testing.Short() { log.Println("Skipping messaging integration tests in short mode.") return } ctx := context.Background() app, err := internal.NewTestApp(ctx, nil) if err != nil { log.Fatalln(err) } client, err = app.Messaging(ctx) if err != nil { log.Fatalln(err) } os.Exit(m.Run()) } func TestSend(t *testing.T) { msg := &messaging.Message{ Topic: "foo-bar", Notification: &messaging.Notification{ Title: "Title", Body: "Body", }, Android: &messaging.AndroidConfig{ Notification: &messaging.AndroidNotification{ Title: "Android Title", Body: "Android Body", }, }, APNS: &messaging.APNSConfig{ Payload: &messaging.APNSPayload{ Aps: &messaging.Aps{ Alert: &messaging.ApsAlert{ Title: "APNS Title", Body: "APNS Body", }, }, }, }, Webpush: &messaging.WebpushConfig{ Notification: &messaging.WebpushNotification{ Title: "Webpush Title", Body: "Webpush Body", }, }, } name, err := client.SendDryRun(context.Background(), msg) if err != nil { t.Fatal(err) } if !messageIDPattern.MatchString(name) { t.Errorf("Send() = %q; want = %q", name, messageIDPattern.String()) } } func TestSendInvalidToken(t *testing.T) { msg := &messaging.Message{Token: "INVALID_TOKEN"} if _, err := client.Send(context.Background(), msg); err == nil || !messaging.IsInvalidArgument(err) { t.Errorf("Send() = %v; want InvalidArgumentError", err) } } func TestSendEach(t *testing.T) { messages := []*messaging.Message{ { Notification: &messaging.Notification{ Title: "Title 1", Body: "Body 1", }, Topic: "foo-bar", }, { Notification: &messaging.Notification{ Title: "Title 2", Body: "Body 
2", }, Topic: "foo-bar", }, { Notification: &messaging.Notification{ Title: "Title 3", Body: "Body 3", }, Token: "INVALID_TOKEN", }, } br, err := client.SendEachDryRun(context.Background(), messages) if err != nil { t.Fatal(err) } if len(br.Responses) != 3 { t.Errorf("len(Responses) = %d; want = 3", len(br.Responses)) } if br.SuccessCount != 2 { t.Errorf("SuccessCount = %d; want = 2", br.SuccessCount) } if br.FailureCount != 1 { t.Errorf("FailureCount = %d; want = 1", br.FailureCount) } for i := 0; i < 2; i++ { sr := br.Responses[i] if err := checkSuccessfulSendResponse(sr); err != nil { t.Errorf("Responses[%d]: %v", i, err) } } sr := br.Responses[2] if sr.Success { t.Errorf("Responses[2]: Success = true; want = false") } if sr.MessageID != "" { t.Errorf("Responses[2]: MessageID = %q; want = %q", sr.MessageID, "") } if sr.Error == nil || !messaging.IsInvalidArgument(sr.Error) { t.Errorf("Responses[2]: Error = %v; want = InvalidArgumentError", sr.Error) } } func TestSendEachFiveHundred(t *testing.T) { var messages []*messaging.Message const limit = 500 for i := 0; i < limit; i++ { m := &messaging.Message{ Topic: fmt.Sprintf("foo-bar-%d", i%10), } messages = append(messages, m) } br, err := client.SendEachDryRun(context.Background(), messages) if err != nil { t.Fatal(err) } if len(br.Responses) != limit { t.Errorf("len(Responses) = %d; want = %d", len(br.Responses), limit) } if br.SuccessCount != limit { t.Errorf("SuccessCount = %d; want = %d", br.SuccessCount, limit) } if br.FailureCount != 0 { t.Errorf("FailureCount = %d; want = 0", br.FailureCount) } for i := 0; i < limit; i++ { sr := br.Responses[i] if err := checkSuccessfulSendResponse(sr); err != nil { t.Errorf("Responses[%d]: %v", i, err) } } } func TestSendEachForMulticast(t *testing.T) { message := &messaging.MulticastMessage{ Notification: &messaging.Notification{ Title: "title", Body: "body", }, Tokens: []string{"INVALID_TOKEN", "ANOTHER_INVALID_TOKEN"}, } br, err := 
client.SendEachForMulticastDryRun(context.Background(), message) if err != nil { t.Fatal(err) } if len(br.Responses) != 2 { t.Errorf("len(Responses) = %d; want = 2", len(br.Responses)) } if br.SuccessCount != 0 { t.Errorf("SuccessCount = %d; want = 0", br.SuccessCount) } if br.FailureCount != 2 { t.Errorf("FailureCount = %d; want = 2", br.FailureCount) } for i := 0; i < 2; i++ { sr := br.Responses[i] if err := checkErrorSendResponse(sr); err != nil { t.Errorf("Responses[%d]: %v", i, err) } } } func TestSendAll(t *testing.T) { t.Skip("Skipping integration tests for deprecated sendAll() API") messages := []*messaging.Message{ { Notification: &messaging.Notification{ Title: "Title 1", Body: "Body 1", }, Topic: "foo-bar", }, { Notification: &messaging.Notification{ Title: "Title 2", Body: "Body 2", }, Topic: "foo-bar", }, { Notification: &messaging.Notification{ Title: "Title 3", Body: "Body 3", }, Token: "INVALID_TOKEN", }, } br, err := client.SendAllDryRun(context.Background(), messages) if err != nil { t.Fatal(err) } if len(br.Responses) != 3 { t.Errorf("len(Responses) = %d; want = 3", len(br.Responses)) } if br.SuccessCount != 2 { t.Errorf("SuccessCount = %d; want = 2", br.SuccessCount) } if br.FailureCount != 1 { t.Errorf("FailureCount = %d; want = 1", br.FailureCount) } for i := 0; i < 2; i++ { sr := br.Responses[i] if err := checkSuccessfulSendResponse(sr); err != nil { t.Errorf("Responses[%d]: %v", i, err) } } sr := br.Responses[2] if sr.Success { t.Errorf("Responses[2]: Success = true; want = false") } if sr.MessageID != "" { t.Errorf("Responses[2]: MessageID = %q; want = %q", sr.MessageID, "") } if sr.Error == nil || !messaging.IsInvalidArgument(sr.Error) { t.Errorf("Responses[2]: Error = %v; want = InvalidArgumentError", sr.Error) } } func TestSendFiveHundred(t *testing.T) { t.Skip("Skipping integration tests for deprecated sendAll() API") var messages []*messaging.Message const limit = 500 for i := 0; i < limit; i++ { m := &messaging.Message{ Topic: 
fmt.Sprintf("foo-bar-%d", i%10), } messages = append(messages, m) } br, err := client.SendAllDryRun(context.Background(), messages) if err != nil { t.Fatal(err) } if len(br.Responses) != limit { t.Errorf("len(Responses) = %d; want = %d", len(br.Responses), limit) } if br.SuccessCount != limit { t.Errorf("SuccessCount = %d; want = %d", br.SuccessCount, limit) } if br.FailureCount != 0 { t.Errorf("FailureCount = %d; want = 0", br.FailureCount) } for i := 0; i < limit; i++ { sr := br.Responses[i] if err := checkSuccessfulSendResponse(sr); err != nil { t.Errorf("Responses[%d]: %v", i, err) } } } func TestSendMulticast(t *testing.T) { t.Skip("Skipping integration tests for deprecated SendMulticast() API") message := &messaging.MulticastMessage{ Notification: &messaging.Notification{ Title: "title", Body: "body", }, Tokens: []string{"INVALID_TOKEN", "ANOTHER_INVALID_TOKEN"}, } br, err := client.SendMulticastDryRun(context.Background(), message) if err != nil { t.Fatal(err) } if len(br.Responses) != 2 { t.Errorf("len(Responses) = %d; want = 2", len(br.Responses)) } if br.SuccessCount != 0 { t.Errorf("SuccessCount = %d; want = 0", br.SuccessCount) } if br.FailureCount != 2 { t.Errorf("FailureCount = %d; want = 2", br.FailureCount) } for i := 0; i < 2; i++ { sr := br.Responses[i] if err := checkErrorSendResponse(sr); err != nil { t.Errorf("Responses[%d]: %v", i, err) } } } func TestSubscribe(t *testing.T) { tmr, err := client.SubscribeToTopic(context.Background(), []string{testRegistrationToken}, "mock-topic") if err != nil { t.Fatal(err) } if tmr.SuccessCount+tmr.FailureCount != 1 { t.Errorf("SubscribeToTopic() = %v; want total 1", tmr) } } func TestUnsubscribe(t *testing.T) { tmr, err := client.UnsubscribeFromTopic(context.Background(), []string{testRegistrationToken}, "mock-topic") if err != nil { t.Fatal(err) } if tmr.SuccessCount+tmr.FailureCount != 1 { t.Errorf("UnsubscribeFromTopic() = %v; want total 1", tmr) } } func checkSuccessfulSendResponse(sr 
*messaging.SendResponse) error { if !sr.Success { return errors.New("Success = false; want = true") } if !messageIDPattern.MatchString(sr.MessageID) { return fmt.Errorf("MessageID = %q; want = %q", sr.MessageID, messageIDPattern.String()) } if sr.Error != nil { return fmt.Errorf("Error = %v; want = nil", sr.Error) } return nil } func checkErrorSendResponse(sr *messaging.SendResponse) error { if sr.Success { return fmt.Errorf("Success = true; want = false") } if sr.MessageID != "" { return fmt.Errorf("MessageID = %q; want = %q", sr.MessageID, "") } if sr.Error == nil || !messaging.IsInvalidArgument(sr.Error) { return fmt.Errorf("Error = %v; want = InvalidArgumentError", sr.Error) } return nil } golang-google-firebase-go-4.18.0/integration/storage/000077500000000000000000000000001505612111400224715ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/integration/storage/storage_test.go000066400000000000000000000055731505612111400255350ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package storage import ( "context" "flag" "fmt" "io/ioutil" "log" "os" "testing" gcs "cloud.google.com/go/storage" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/integration/internal" "firebase.google.com/go/v4/storage" ) var ctx context.Context var client *storage.Client func TestMain(m *testing.M) { flag.Parse() if testing.Short() { log.Println("skipping storage integration tests in short mode.") os.Exit(0) } pid, err := internal.ProjectID() if err != nil { log.Fatalln(err) } ctx = context.Background() app, err := internal.NewTestApp(ctx, &firebase.Config{ StorageBucket: fmt.Sprintf("%s.appspot.com", pid), }) if err != nil { log.Fatalln(err) } client, err = app.Storage(ctx) if err != nil { log.Fatalln(err) } os.Exit(m.Run()) } func TestDefaultBucket(t *testing.T) { bucket, err := client.DefaultBucket() if bucket == nil || err != nil { t.Errorf("DefaultBucket() = (%v, %v); want (bucket, nil)", bucket, err) } if err := verifyBucket(bucket); err != nil { t.Fatal(err) } } func TestCustomBucket(t *testing.T) { pid, err := internal.ProjectID() if err != nil { t.Fatal(err) } bucket, err := client.Bucket(pid + ".appspot.com") if bucket == nil || err != nil { t.Errorf("Bucket() = (%v, %v); want (bucket, nil)", bucket, err) } if err := verifyBucket(bucket); err != nil { t.Fatal(err) } } func TestNonExistingBucket(t *testing.T) { bucket, err := client.Bucket("non-existing") if bucket == nil || err != nil { t.Errorf("Bucket() = (%v, %v); want (bucket, nil)", bucket, err) } if _, err := bucket.Attrs(context.Background()); err == nil { t.Errorf("bucket.Attr() = nil; want error") } } func verifyBucket(bucket *gcs.BucketHandle) error { const expected = "Hello World" // Create new object o := bucket.Object("data") w := o.NewWriter(ctx) w.ContentType = "text/plain" if _, err := w.Write([]byte(expected)); err != nil { return err } if err := w.Close(); err != nil { return err } // Read the created object r, err := o.NewReader(ctx) if err != nil { return err } defer 
r.Close() b, err := ioutil.ReadAll(r) if err != nil { return err } if string(b) != expected { return fmt.Errorf("fetched content: %q; want: %q", string(b), expected) } // Delete the object return o.Delete(ctx) } golang-google-firebase-go-4.18.0/internal/000077500000000000000000000000001505612111400203165ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/internal/errors.go000066400000000000000000000134561505612111400221720ustar00rootroot00000000000000// Copyright 2020 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package internal import ( "encoding/json" "fmt" "net" "net/http" "net/url" "os" "syscall" ) // ErrorCode represents the platform-wide error codes that can be raised by // Admin SDK APIs. type ErrorCode string const ( // InvalidArgument is a OnePlatform error code. InvalidArgument ErrorCode = "INVALID_ARGUMENT" // FailedPrecondition is a OnePlatform error code. FailedPrecondition ErrorCode = "FAILED_PRECONDITION" // OutOfRange is a OnePlatform error code. OutOfRange ErrorCode = "OUT_OF_RANGE" // Unauthenticated is a OnePlatform error code. Unauthenticated ErrorCode = "UNAUTHENTICATED" // PermissionDenied is a OnePlatform error code. PermissionDenied ErrorCode = "PERMISSION_DENIED" // NotFound is a OnePlatform error code. NotFound ErrorCode = "NOT_FOUND" // Conflict is a custom error code that represents HTTP 409 responses. // // OnePlatform APIs typically respond with ABORTED or ALREADY_EXISTS explicitly. 
But a few // old APIs send HTTP 409 Conflict without any additional details to distinguish between the two // cases. For these we currently use this error code. As more APIs adopt OnePlatform conventions // this will become less important. Conflict ErrorCode = "CONFLICT" // Aborted is a OnePlatform error code. Aborted ErrorCode = "ABORTED" // AlreadyExists is a OnePlatform error code. AlreadyExists ErrorCode = "ALREADY_EXISTS" // ResourceExhausted is a OnePlatform error code. ResourceExhausted ErrorCode = "RESOURCE_EXHAUSTED" // Cancelled is a OnePlatform error code. Cancelled ErrorCode = "CANCELLED" // DataLoss is a OnePlatform error code. DataLoss ErrorCode = "DATA_LOSS" // Unknown is a OnePlatform error code. Unknown ErrorCode = "UNKNOWN" // Internal is a OnePlatform error code. Internal ErrorCode = "INTERNAL" // Unavailable is a OnePlatform error code. Unavailable ErrorCode = "UNAVAILABLE" // DeadlineExceeded is a OnePlatform error code. DeadlineExceeded ErrorCode = "DEADLINE_EXCEEDED" ) // FirebaseError is an error type containing an error code string. type FirebaseError struct { ErrorCode ErrorCode String string Response *http.Response Ext map[string]interface{} } func (fe *FirebaseError) Error() string { return fe.String } // HasPlatformErrorCode checks if the given error contains a specific error code. func HasPlatformErrorCode(err error, code ErrorCode) bool { fe, ok := err.(*FirebaseError) return ok && fe.ErrorCode == code } var httpStatusToErrorCodes = map[int]ErrorCode{ http.StatusBadRequest: InvalidArgument, http.StatusUnauthorized: Unauthenticated, http.StatusForbidden: PermissionDenied, http.StatusNotFound: NotFound, http.StatusConflict: Conflict, http.StatusTooManyRequests: ResourceExhausted, http.StatusInternalServerError: Internal, http.StatusServiceUnavailable: Unavailable, } // NewFirebaseError creates a new error from the given HTTP response. 
func NewFirebaseError(resp *Response) *FirebaseError { code, ok := httpStatusToErrorCodes[resp.Status] if !ok { code = Unknown } return &FirebaseError{ ErrorCode: code, String: fmt.Sprintf("unexpected http response with status: %d\n%s", resp.Status, string(resp.Body)), Response: resp.LowLevelResponse(), Ext: make(map[string]interface{}), } } // NewFirebaseErrorOnePlatform parses the response payload as a GCP error response // and create an error from the details extracted. // // If the response failes to parse, or otherwise doesn't provide any useful details // NewFirebaseErrorOnePlatform creates an error with some sensible defaults. func NewFirebaseErrorOnePlatform(resp *Response) *FirebaseError { base := NewFirebaseError(resp) var gcpError struct { Error struct { Status string `json:"status"` Message string `json:"message"` } `json:"error"` } json.Unmarshal(resp.Body, &gcpError) // ignore any json parse errors at this level if gcpError.Error.Status != "" { base.ErrorCode = ErrorCode(gcpError.Error.Status) } if gcpError.Error.Message != "" { base.String = gcpError.Error.Message } return base } func newFirebaseErrorTransport(err error) *FirebaseError { var code ErrorCode var msg string if os.IsTimeout(err) { code = DeadlineExceeded msg = fmt.Sprintf("timed out while making an http call: %v", err) } else if isConnectionRefused(err) { code = Unavailable msg = fmt.Sprintf("failed to establish a connection: %v", err) } else { code = Unknown msg = fmt.Sprintf("unknown error while making an http call: %v", err) } return &FirebaseError{ ErrorCode: code, String: msg, Ext: make(map[string]interface{}), } } // isConnectionRefused attempts to determine if the given error was caused by a failure to establish a // connection. // // A net.OpError where the Op field is set to "dial" or "read" is considered a connection refused // error. Similarly an ECONNREFUSED error code (Linux-specific) is also considered a connection // refused error. 
func isConnectionRefused(err error) bool { switch t := err.(type) { case *url.Error: return isConnectionRefused(t.Err) case *net.OpError: if t.Op == "dial" || t.Op == "read" { return true } return isConnectionRefused(t.Err) case syscall.Errno: return t == syscall.ECONNREFUSED } return false } golang-google-firebase-go-4.18.0/internal/errors_test.go000066400000000000000000000211441505612111400232220ustar00rootroot00000000000000// Copyright 2020 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package internal import ( "context" "encoding/json" "errors" "fmt" "io/ioutil" "net" "net/http" "net/http/httptest" "strings" "syscall" "testing" ) var platformErrorCodes = []ErrorCode{ InvalidArgument, Unauthenticated, NotFound, Aborted, AlreadyExists, Internal, Unavailable, Unknown, } func TestPlatformError(t *testing.T) { var body string handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) w.Write([]byte(body)) }) server := httptest.NewServer(handler) defer server.Close() client := &HTTPClient{ Client: http.DefaultClient, } get := &Request{ Method: http.MethodGet, URL: server.URL, } want := "Test error message" for _, code := range platformErrorCodes { body = fmt.Sprintf(`{ "error": { "status": %q, "message": "Test error message" } }`, code) resp, err := client.Do(context.Background(), get) if resp != nil || err == nil || err.Error() != want { t.Fatalf("[%s]: Do() = (%v, %v); want = (nil, %q)", code, resp, err, want) } if !HasPlatformErrorCode(err, code) { t.Errorf("[%s]: HasPlatformErrorCode() = false; want = true", code) } fe, ok := err.(*FirebaseError) if !ok { t.Fatalf("[%s]: Do() err = %v; want = FirebaseError", code, err) } if fe.ErrorCode != code { t.Errorf("[%s]: Do() err.ErrorCode = %q; want = %q", code, fe.ErrorCode, code) } if fe.Response == nil { t.Fatalf("[%s]: Do() err.Response = nil; want = non-nil", code) } if fe.Response.StatusCode != http.StatusNotFound { t.Errorf("[%s]: Do() err.Response.StatusCode = %d; want = %d", code, fe.Response.StatusCode, http.StatusNotFound) } if fe.Ext == nil || len(fe.Ext) > 0 { t.Errorf("[%s]: Do() err.Ext = %v; want = empty-map", code, fe.Ext) } } } func TestPlatformErrorWithoutDetails(t *testing.T) { var status int handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(status) w.Write([]byte("{}")) }) server := httptest.NewServer(handler) defer server.Close() client := &HTTPClient{ Client: http.DefaultClient, } get := 
&Request{ Method: http.MethodGet, URL: server.URL, } httpStatusMappings := map[int]ErrorCode{ http.StatusNotImplemented: Unknown, } // Add known error code mappings for k, v := range httpStatusToErrorCodes { httpStatusMappings[k] = v } for httpStatus, platformCode := range httpStatusMappings { status = httpStatus want := fmt.Sprintf("unexpected http response with status: %d\n{}", httpStatus) resp, err := client.Do(context.Background(), get) if resp != nil || err == nil || err.Error() != want { t.Fatalf("[%d]: Do() = (%v, %v); want = (nil, %q)", httpStatus, resp, err, want) } if !HasPlatformErrorCode(err, platformCode) { t.Errorf("[%d]: HasPlatformErrorCode(%q) = false; want = true", httpStatus, platformCode) } fe, ok := err.(*FirebaseError) if !ok { t.Fatalf("[%d]: Do() err = %v; want = FirebaseError", httpStatus, err) } if fe.ErrorCode != platformCode { t.Errorf("[%d]: Do() err.ErrorCode = %q; want = %q", httpStatus, fe.ErrorCode, platformCode) } if fe.Response == nil { t.Fatalf("[%d]: Do() err.Response = nil; want = non-nil", httpStatus) } if fe.Response.StatusCode != httpStatus { t.Errorf("[%d]: Do() err.Response.StatusCode = %d; want = %d", httpStatus, fe.Response.StatusCode, httpStatus) } if fe.Ext == nil || len(fe.Ext) > 0 { t.Errorf("[%d]: Do() err.Ext = %v; want = empty-map", httpStatus, fe.Ext) } } } func TestTimeoutError(t *testing.T) { client := &HTTPClient{ Client: &http.Client{ Transport: &faultyTransport{ Err: &timeoutError{}, }, }, } get := &Request{ Method: http.MethodGet, URL: "http://test.url", } want := "timed out while making an http call" resp, err := client.Do(context.Background(), get) if resp != nil || err == nil || !strings.HasPrefix(err.Error(), want) { t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) } fe, ok := err.(*FirebaseError) if !ok { t.Fatalf("Do() err = %v; want = FirebaseError", err) } if fe.ErrorCode != DeadlineExceeded { t.Errorf("Do() err.ErrorCode = %q; want = %q", fe.ErrorCode, DeadlineExceeded) } if 
fe.Response != nil { t.Errorf("Do() err.Response = %v; want = nil", fe.Response) } if fe.Ext == nil || len(fe.Ext) > 0 { t.Errorf("Do() err.Ext = %v; want = empty-map", fe.Ext) } } type timeoutError struct{} func (t *timeoutError) Error() string { return "test timeout error" } func (t *timeoutError) Timeout() bool { return true } func TestNetworkOutageError(t *testing.T) { errors := []struct { name string err error }{ {"NetDialError", &net.OpError{Op: "dial", Err: errors.New("test error")}}, {"NetReadError", &net.OpError{Op: "read", Err: errors.New("test error")}}, { "WrappedNetReadError", &net.OpError{ Op: "test", Err: &net.OpError{Op: "read", Err: errors.New("test error")}, }, }, {"ECONNREFUSED", syscall.ECONNREFUSED}, } get := &Request{ Method: http.MethodGet, URL: "http://test.url", } want := "failed to establish a connection" for _, tc := range errors { t.Run(tc.name, func(t *testing.T) { client := &HTTPClient{ Client: &http.Client{ Transport: &faultyTransport{ Err: tc.err, }, }, } resp, err := client.Do(context.Background(), get) if resp != nil || err == nil || !strings.HasPrefix(err.Error(), want) { t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) } fe, ok := err.(*FirebaseError) if !ok { t.Fatalf("Do() err = %v; want = FirebaseError", err) } if fe.ErrorCode != Unavailable { t.Errorf("Do() err.ErrorCode = %q; want = %q", fe.ErrorCode, Unavailable) } if fe.Response != nil { t.Errorf("Do() err.Response = %v; want = nil", fe.Response) } if fe.Ext == nil || len(fe.Ext) > 0 { t.Errorf("Do() err.Ext = %v; want = empty-map", fe.Ext) } }) } } func TestUnknownNetworkError(t *testing.T) { client := &HTTPClient{ Client: &http.Client{ Transport: &faultyTransport{ Err: errors.New("unknown error"), }, }, } get := &Request{ Method: http.MethodGet, URL: "http://test.url", } want := "unknown error while making an http call" resp, err := client.Do(context.Background(), get) if resp != nil || err == nil || !strings.HasPrefix(err.Error(), want) { t.Fatalf("Do() = 
(%v, %v); want = (nil, %q)", resp, err, want) } fe, ok := err.(*FirebaseError) if !ok { t.Fatalf("Do() err = %v; want = FirebaseError", err) } if fe.ErrorCode != Unknown { t.Errorf("Do() err.ErrorCode = %q; want = %q", fe.ErrorCode, Unknown) } if fe.Response != nil { t.Errorf("Do() err.Response = %v; want = nil", fe.Response) } if fe.Ext == nil || len(fe.Ext) > 0 { t.Errorf("Do() err.Ext = %v; want = empty-map", fe.Ext) } } func TestErrorHTTPResponse(t *testing.T) { body := `{"key": "value"}` handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(body)) }) server := httptest.NewServer(handler) defer server.Close() client := &HTTPClient{ Client: http.DefaultClient, } get := &Request{ Method: http.MethodGet, URL: server.URL, } want := fmt.Sprintf("unexpected http response with status: 500\n%s", body) resp, err := client.Do(context.Background(), get) if resp != nil || err == nil || err.Error() != want { t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) } fe, ok := err.(*FirebaseError) if !ok { t.Fatalf("Do() err = %v; want = FirebaseError", err) } hr := fe.Response defer hr.Body.Close() if hr.StatusCode != http.StatusInternalServerError { t.Errorf("Do() Response.StatusCode = %d; want = %d", hr.StatusCode, http.StatusInternalServerError) } b, err := ioutil.ReadAll(hr.Body) if err != nil { t.Fatalf("ReadAll(Response.Body) = %v", err) } var m map[string]string if err := json.Unmarshal(b, &m); err != nil { t.Fatalf("Unmarshal(Response.Body) = %v", err) } if len(m) != 1 || m["key"] != "value" { t.Errorf("Unmarshal(Response.Body) = %v; want = {key: value}", m) } } golang-google-firebase-go-4.18.0/internal/http_client.go000066400000000000000000000312351505612111400231660ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. 
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package internal

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"math"
	"net/http"
	"runtime"
	"strconv"
	"strings"
	"time"

	"google.golang.org/api/option"
	"google.golang.org/api/transport"
)

// HTTPClient is a convenient API to make HTTP calls.
//
// This API handles repetitive tasks such as entity serialization and deserialization
// when making HTTP calls. It provides a convenient mechanism to set headers and query
// parameters on outgoing requests, while enforcing that an explicit context is used per request.
// Responses returned by HTTPClient can be easily unmarshalled as JSON.
//
// HTTPClient also handles automatically retrying failed HTTP requests.
type HTTPClient struct {
	// Client is the underlying http.Client every attempt is sent through.
	Client *http.Client
	// RetryConfig controls retries. When nil, the retry policy is never consulted and each
	// request is attempted exactly once.
	RetryConfig *RetryConfig
	// CreateErrFn builds the error returned for an unsuccessful response. A CreateErrFn set on
	// the Request takes precedence; when neither is set NewFirebaseErrorOnePlatform is used.
	CreateErrFn CreateErrFn
	// SuccessFn decides whether a response counts as successful. A SuccessFn set on the Request
	// takes precedence; when neither is set HasSuccessStatus is used.
	SuccessFn SuccessFn
	// Opts are applied to every outgoing request, before the options of the Request itself.
	Opts []HTTPOption
}

// SuccessFn is a function that checks if a Response indicates success.
type SuccessFn func(r *Response) bool

// CreateErrFn is a function that creates an error from a given Response.
type CreateErrFn func(r *Response) error

// NewHTTPClient creates a new HTTPClient using the provided client options and the default
// RetryConfig.
//
// NewHTTPClient returns the created HTTPClient along with the target endpoint URL. The endpoint
// is obtained from the client options passed into the function.
func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*HTTPClient, string, error) { hc, endpoint, err := transport.NewHTTPClient(ctx, opts...) if err != nil { return nil, "", err } return WithDefaultRetryConfig(hc), endpoint, nil } // WithDefaultRetryConfig creates a new HTTPClient using the provided client and the default // RetryConfig. // // The default RetryConfig retries requests on all low-level network errors as well as on HTTP // ServiceUnavailable (503) error. Repeatedly failing requests are retried up to 4 times // with exponential backoff. Retry delay is never longer than 2 minutes. func WithDefaultRetryConfig(hc *http.Client) *HTTPClient { twoMinutes := time.Duration(2) * time.Minute return &HTTPClient{ Client: hc, RetryConfig: &RetryConfig{ MaxRetries: 4, CheckForRetry: retryNetworkAndHTTPErrors( http.StatusServiceUnavailable, ), ExpBackoffFactor: 0.5, MaxDelay: &twoMinutes, }, } } // Request contains all the parameters required to construct an outgoing HTTP request. type Request struct { Method string URL string Body HTTPEntity Opts []HTTPOption SuccessFn SuccessFn CreateErrFn CreateErrFn } // Response contains information extracted from an HTTP response. type Response struct { Status int Header http.Header Body []byte resp *http.Response } // LowLevelResponse returns an http.Response that represents the underlying low-level HTTP // response. // // This always returns a buffered copy of the original HTTP response. Body can be read from the // returned response with no impact on the underlying HTTP connection. Closing the Body on the // returned response is a No-op. func (r *Response) LowLevelResponse() *http.Response { // If the Response instance was initialized manually (as is the case when parsing batch // responses) the resp field may be nil. if r.resp == nil { return nil } resp := *r.resp resp.Body = ioutil.NopCloser(bytes.NewBuffer(r.Body)) return &resp } // Do executes the given Request, and returns a Response. 
// // If a RetryConfig is specified on the client, Do attempts to retry failing requests. // // If SuccessFn is set on the client or on the request, the response is validated against that // function. If this validation fails, returns an error. These errors are created using the // CreateErrFn on the client or on the request. If neither is set, CreatePlatformError is // used as the default error function. func (c *HTTPClient) Do(ctx context.Context, req *Request) (*Response, error) { var result *attemptResult for retries := 0; ; retries++ { hr, err := req.buildHTTPRequest(c.Opts) if err != nil { return nil, err } result = c.attempt(ctx, hr, retries) if !result.Retry { break } if err = result.waitForRetry(ctx); err != nil { return nil, err } } return c.handleResult(req, result) } // DoAndUnmarshal behaves similar to Do, but additionally unmarshals the response payload into // the given pointer. // // Unmarshal takes place only if the response does not represent an error (as determined by // the Do function) and v is not nil. If the unmarshal fails, an error is returned even if the // original response indicated success. func (c *HTTPClient) DoAndUnmarshal(ctx context.Context, req *Request, v interface{}) (*Response, error) { resp, err := c.Do(ctx, req) if err != nil { return nil, err } if v != nil { if err := json.Unmarshal(resp.Body, v); err != nil { return nil, fmt.Errorf("error while parsing response: %v", err) } } return resp, nil } func (c *HTTPClient) attempt(ctx context.Context, hr *http.Request, retries int) *attemptResult { resp, err := c.Client.Do(hr.WithContext(ctx)) result := &attemptResult{} if err != nil { result.Err = err } else { // Read the response body here forcing any I/O errors to occur so that retry logic will // cover them as well. ir, err := newResponse(resp) result.Resp = ir result.Err = err } // If a RetryConfig is available, always consult it to determine if the request should be retried // or not. 
Even if there was a network error, we may not want to retry the request based on the // RetryConfig that is in effect. if c.RetryConfig != nil { delay, retry := c.RetryConfig.retryDelay(retries, resp, result.Err) result.RetryAfter = delay result.Retry = retry } return result } func (c *HTTPClient) handleResult(req *Request, result *attemptResult) (*Response, error) { if result.Err != nil { return nil, newFirebaseErrorTransport(result.Err) } if !c.success(req, result.Resp) { return nil, c.newError(req, result.Resp) } return result.Resp, nil } func (c *HTTPClient) success(req *Request, resp *Response) bool { var successFn SuccessFn if req.SuccessFn != nil { successFn = req.SuccessFn } else if c.SuccessFn != nil { successFn = c.SuccessFn } else { successFn = HasSuccessStatus } return successFn(resp) } func (c *HTTPClient) newError(req *Request, resp *Response) error { createErr := func(r *Response) error { return NewFirebaseErrorOnePlatform(r) } if req.CreateErrFn != nil { createErr = req.CreateErrFn } else if c.CreateErrFn != nil { createErr = c.CreateErrFn } return createErr(resp) } type attemptResult struct { Resp *Response Err error Retry bool RetryAfter time.Duration } func (r *attemptResult) waitForRetry(ctx context.Context) error { if r.RetryAfter > 0 { select { case <-ctx.Done(): case <-time.After(r.RetryAfter): } } return ctx.Err() } func (r *Request) buildHTTPRequest(opts []HTTPOption) (*http.Request, error) { var data io.Reader if r.Body != nil { b, err := r.Body.Bytes() if err != nil { return nil, err } data = bytes.NewBuffer(b) opts = append(opts, WithHeader("Content-Type", r.Body.Mime())) } req, err := http.NewRequest(r.Method, r.URL, data) if err != nil { return nil, err } opts = append(opts, r.Opts...) for _, o := range opts { o(req) } return req, nil } // HTTPEntity represents a payload that can be included in an outgoing HTTP request. 
type HTTPEntity interface { Bytes() ([]byte, error) Mime() string } type jsonEntity struct { Val interface{} } // NewJSONEntity creates a new HTTPEntity that will be serialized into JSON. func NewJSONEntity(v interface{}) HTTPEntity { return &jsonEntity{Val: v} } func (e *jsonEntity) Bytes() ([]byte, error) { return json.Marshal(e.Val) } func (e *jsonEntity) Mime() string { return "application/json" } func newResponse(resp *http.Response) (*Response, error) { defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } return &Response{ Status: resp.StatusCode, Body: b, Header: resp.Header, resp: resp, }, nil } // HTTPOption is an additional parameter that can be specified to customize an outgoing request. type HTTPOption func(*http.Request) // WithHeader creates an HTTPOption that will set an HTTP header on the request. func WithHeader(key, value string) HTTPOption { return func(r *http.Request) { r.Header.Set(key, value) } } // WithQueryParam creates an HTTPOption that will set a query parameter on the request. func WithQueryParam(key, value string) HTTPOption { return func(r *http.Request) { q := r.URL.Query() q.Add(key, value) r.URL.RawQuery = q.Encode() } } // WithQueryParams creates an HTTPOption that will set all the entries of qp as query parameters // on the request. func WithQueryParams(qp map[string]string) HTTPOption { return func(r *http.Request) { q := r.URL.Query() for k, v := range qp { q.Add(k, v) } r.URL.RawQuery = q.Encode() } } // HasSuccessStatus returns true if the response status code is in the 2xx range. func HasSuccessStatus(r *Response) bool { return r.Status >= http.StatusOK && r.Status < http.StatusNotModified } // RetryConfig specifies how the HTTPClient should retry failing HTTP requests. // // A request is never retried more than MaxRetries times. If CheckForRetry is nil, all network // errors, and all 400+ HTTP status codes are retried. 
If an HTTP error response contains the // Retry-After header, it is always respected. Otherwise retries are delayed with exponential // backoff. Set ExpBackoffFactor to 0 to disable exponential backoff, and retry immediately // after each error. // // If MaxDelay is set, retries delay gets capped by that value. If the Retry-After header // requires a longer delay than MaxDelay, retries are not attempted. type RetryConfig struct { MaxRetries int CheckForRetry RetryCondition ExpBackoffFactor float64 MaxDelay *time.Duration } // RetryCondition determines if an HTTP request should be retried depending on its last outcome. type RetryCondition func(resp *http.Response, networkErr error) bool func (rc *RetryConfig) retryDelay(retries int, resp *http.Response, err error) (time.Duration, bool) { if !rc.retryEligible(retries, resp, err) { return 0, false } estimatedDelay := rc.estimateDelayBeforeNextRetry(retries) serverRecommendedDelay := parseRetryAfterHeader(resp) if serverRecommendedDelay > estimatedDelay { estimatedDelay = serverRecommendedDelay } if rc.MaxDelay != nil && estimatedDelay > *rc.MaxDelay { return 0, false } return estimatedDelay, true } func (rc *RetryConfig) retryEligible(retries int, resp *http.Response, err error) bool { if retries >= rc.MaxRetries { return false } if rc.CheckForRetry == nil { return err != nil || resp.StatusCode >= 500 } return rc.CheckForRetry(resp, err) } func (rc *RetryConfig) estimateDelayBeforeNextRetry(retries int) time.Duration { if retries == 0 { return 0 } delayInSeconds := int64(math.Pow(2, float64(retries)) * rc.ExpBackoffFactor) estimatedDelay := time.Duration(delayInSeconds) * time.Second if rc.MaxDelay != nil && estimatedDelay > *rc.MaxDelay { estimatedDelay = *rc.MaxDelay } return estimatedDelay } var retryTimeClock Clock = SystemClock func parseRetryAfterHeader(resp *http.Response) time.Duration { if resp == nil { return 0 } retryAfterHeader := resp.Header.Get("retry-after") if retryAfterHeader == "" { return 0 } if 
delayInSeconds, err := strconv.ParseInt(retryAfterHeader, 10, 64); err == nil { return time.Duration(delayInSeconds) * time.Second } if timestamp, err := http.ParseTime(retryAfterHeader); err == nil { return timestamp.Sub(retryTimeClock.Now()) } return 0 } func retryNetworkAndHTTPErrors(statusCodes ...int) RetryCondition { return func(resp *http.Response, networkErr error) bool { if networkErr != nil { return true } for _, retryOnStatus := range statusCodes { if resp.StatusCode == retryOnStatus { return true } } return false } } // GetMetricsHeader constructs header value for metrics attribution func GetMetricsHeader(sdkVersion string) string { goVersion := strings.TrimPrefix(runtime.Version(), "go") return fmt.Sprintf("gl-go/%s fire-admin/%s", goVersion, sdkVersion) } golang-google-firebase-go-4.18.0/internal/http_client_test.go000066400000000000000000000554431505612111400242340ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package internal

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"

	"google.golang.org/api/option"
)

const defaultMaxRetries = 4

var (
	testRetryConfig = RetryConfig{
		MaxRetries:       4,
		ExpBackoffFactor: 0.5,
	}

	tokenSourceOpt = option.WithTokenSource(&MockTokenSource{AccessToken: "test"})
)

// testRequests lists the requests sent by TestHTTPClient, together with the method, body,
// headers and query parameters the test server expects to receive for each of them.
var testRequests = []struct {
	req     *Request
	method  string
	body    string
	headers map[string]string
	query   map[string]string
}{
	{
		req: &Request{
			Method: http.MethodGet,
		},
		method: http.MethodGet,
	},
	{
		req: &Request{
			Method: http.MethodGet,
			Opts: []HTTPOption{
				WithHeader("Test-Header", "value1"),
				WithQueryParam("testParam", "value2"),
			},
		},
		method:  http.MethodGet,
		headers: map[string]string{"Test-Header": "value1"},
		query:   map[string]string{"testParam": "value2"},
	},
	{
		req: &Request{
			Method: http.MethodPost,
			Body:   NewJSONEntity(map[string]string{"foo": "bar"}),
			Opts: []HTTPOption{
				WithHeader("Test-Header", "value1"),
				WithQueryParam("testParam1", "value2"),
				WithQueryParam("testParam2", "value3"),
			},
		},
		method:  http.MethodPost,
		body:    "{\"foo\":\"bar\"}",
		headers: map[string]string{"Test-Header": "value1"},
		query:   map[string]string{"testParam1": "value2", "testParam2": "value3"},
	},
	{
		req: &Request{
			Method: http.MethodPost,
			Body:   NewJSONEntity("body"),
			Opts: []HTTPOption{
				WithHeader("Test-Header", "value1"),
				WithQueryParams(map[string]string{"testParam1": "value2", "testParam2": "value3"}),
			},
		},
		method:  http.MethodPost,
		body:    "\"body\"",
		headers: map[string]string{"Test-Header": "value1"},
		query:   map[string]string{"testParam1": "value2", "testParam2": "value3"},
	},
	{
		req: &Request{
			Method: http.MethodPut,
			Body:   NewJSONEntity(nil),
			Opts: []HTTPOption{
				WithHeader("Test-Header", "value1"),
				WithQueryParams(map[string]string{"testParam1": "value2", "testParam2": "value3"}),
			},
		},
		method:  http.MethodPut,
		body:    "null",
		headers: map[string]string{"Test-Header": "value1"},
		query:   map[string]string{"testParam1": "value2", "testParam2": "value3"},
	},
}

// TestHTTPClient sends every entry of testRequests and checks both what the server received
// and what the client unmarshalled from the response.
func TestHTTPClient(t *testing.T) {
	want := map[string]interface{}{
		"key1": "value1",
		"key2": float64(100),
	}
	b, err := json.Marshal(want)
	if err != nil {
		t.Fatal(err)
	}

	idx := 0
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		want := testRequests[idx]
		if r.Method != want.method {
			t.Errorf("[%d] Method = %q; want = %q", idx, r.Method, want.method)
		}
		for k, v := range want.headers {
			h := r.Header.Get(k)
			if h != v {
				t.Errorf("[%d] Header(%q) = %q; want = %q", idx, k, h, v)
			}
		}
		if want.query == nil {
			if r.URL.Query().Encode() != "" {
				t.Errorf("[%d] Query = %v; want = empty", idx, r.URL.Query().Encode())
			}
		}
		for k, v := range want.query {
			q := r.URL.Query().Get(k)
			if q != v {
				t.Errorf("[%d] Query(%q) = %q; want = %q", idx, k, q, v)
			}
		}
		if want.body != "" {
			h := r.Header.Get("Content-Type")
			if h != "application/json" {
				t.Errorf("[%d] Content-Type = %q; want = %q", idx, h, "application/json")
			}
			wb := []byte(want.body)
			gb, err := ioutil.ReadAll(r.Body)
			if err != nil {
				// The handler does not run on the test goroutine, so t.Fatal must not be
				// called here.
				t.Errorf("[%d] ReadAll(Body) = %v", idx, err)
			}
			if !reflect.DeepEqual(wb, gb) {
				t.Errorf("[%d] Body = %q; want = %q", idx, string(gb), string(wb))
			}
		}

		idx++
		w.Header().Set("Content-Type", "application/json")
		w.Write(b)
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{Client: http.DefaultClient}
	for _, tc := range testRequests {
		tc.req.URL = server.URL
		var got map[string]interface{}
		resp, err := client.DoAndUnmarshal(context.Background(), tc.req, &got)
		if err != nil {
			t.Fatal(err)
		}

		if resp.Status != http.StatusOK {
			t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK)
		}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("Body = %v; want = %v", got, want)
		}
	}
}

// TestDefaultOpts checks that client-level options are applied to outgoing requests.
func TestDefaultOpts(t *testing.T) {
	var header string
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header = r.Header.Get("Test-Header")
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{
		Client: http.DefaultClient,
		Opts: []HTTPOption{
			WithHeader("Test-Header", "test-value"),
		},
	}
	req := &Request{
		Method: http.MethodGet,
		URL:    fmt.Sprintf("%s%s", server.URL, wantURL),
	}

	resp, err := client.Do(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	if resp.Status != http.StatusOK {
		t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK)
	}
	if header != "test-value" {
		t.Errorf("Test-Header = %q; want = %q", header, "test-value")
	}
}

// TestSuccessFn checks that a client-level SuccessFn can turn a 200 response into an error.
func TestSuccessFn(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{
		Client: http.DefaultClient,
		SuccessFn: func(r *Response) bool {
			return false
		},
	}
	get := &Request{
		Method: http.MethodGet,
		URL:    server.URL,
	}
	want := "unexpected http response with status: 200\n{}"

	resp, err := client.Do(context.Background(), get)
	if resp != nil || err == nil || err.Error() != want {
		t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want)
	}

	fe, ok := err.(*FirebaseError)
	if !ok {
		t.Fatalf("Do() err = %v; want = FirebaseError", err)
	}

	if fe.ErrorCode != Unknown {
		t.Errorf("Do() err.ErrorCode = %q; want = %q", fe.ErrorCode, Unknown)
	}
	if fe.Response == nil {
		t.Fatalf("Do() err.Response = nil; want = non-nil")
	}
	if fe.Response.StatusCode != http.StatusOK {
		t.Errorf("Do() err.Response.StatusCode = %d; want = %d", fe.Response.StatusCode, http.StatusOK)
	}
}

// TestSuccessFnOnRequest checks that a request-level SuccessFn is honored.
func TestSuccessFnOnRequest(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{
		Client: http.DefaultClient,
	}
	get := &Request{
		Method: http.MethodGet,
		URL:    server.URL,
		SuccessFn: func(r *Response) bool {
			return false
		},
	}
	want := "unexpected http response with status: 200\n{}"

	resp, err := client.Do(context.Background(), get)
	if resp != nil || err == nil || err.Error() != want {
		t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want)
	}

	if !HasPlatformErrorCode(err, Unknown) {
		t.Errorf("ErrorCode = %q; want = %q", err.(*FirebaseError).ErrorCode, Unknown)
	}
}

// TestCreateErrFn checks that a client-level CreateErrFn produces the returned error.
func TestCreateErrFn(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{
		Client: http.DefaultClient,
		CreateErrFn: func(r *Response) error {
			return fmt.Errorf("custom error with status: %d", r.Status)
		},
	}
	get := &Request{
		Method: http.MethodGet,
		URL:    server.URL,
	}
	want := "custom error with status: 404"

	resp, err := client.Do(context.Background(), get)
	if resp != nil || err == nil || err.Error() != want {
		t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want)
	}
}

// TestCreateErrFnOnRequest checks that a request-level CreateErrFn overrides the client's.
func TestCreateErrFnOnRequest(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{
		Client: http.DefaultClient,
		CreateErrFn: func(r *Response) error {
			return fmt.Errorf("custom error with status: %d", r.Status)
		},
	}
	get := &Request{
		Method: http.MethodGet,
		URL:    server.URL,
		CreateErrFn: func(r *Response) error {
			return fmt.Errorf("custom error from req with status: %d", r.Status)
		},
	}
	want := "custom error from req with status: 404"

	resp, err := client.Do(context.Background(), get)
	if resp != nil || err == nil || err.Error() != want {
		t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want)
	}
}

// TestContext checks that requests fail once their context has been cancelled.
func TestContext(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{Client: http.DefaultClient}
	ctx, cancel := context.WithCancel(context.Background())
	resp, err := client.Do(ctx, &Request{
		Method: http.MethodGet,
		URL:    server.URL,
	})
	if err != nil {
		t.Fatal(err)
	}

	cancel()
	resp, err = client.Do(ctx, &Request{
		Method: http.MethodGet,
		URL:    server.URL,
	})
	if resp != nil || err == nil {
		t.Errorf("Do() = (%v; %v); want = (nil, error)", resp, err)
	}
}

// TestInvalidURL checks that Do reports an error for a URL the test expects to be unusable.
func TestInvalidURL(t *testing.T) {
	req := &Request{
		Method: http.MethodGet,
		URL:    "http://localhost:250/mock.url",
	}
	client := &HTTPClient{Client: http.DefaultClient}
	if _, err := client.Do(context.Background(), req); err == nil {
		t.Errorf("Do() = nil; want error")
	}
}

// TestUnmarshalError checks that DoAndUnmarshal fails when the target cannot hold the payload.
func TestUnmarshalError(t *testing.T) {
	data := map[string]interface{}{
		"foo": "bar",
	}
	b, err := json.Marshal(data)
	if err != nil {
		t.Fatal(err)
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Write(b)
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	req := &Request{Method: http.MethodGet, URL: server.URL}
	client := &HTTPClient{Client: http.DefaultClient}

	var got func()
	_, err = client.DoAndUnmarshal(context.Background(), req, &got)
	if err == nil {
		t.Errorf("DoAndUnmarshal() = nil; want error")
	}
}

// TestRetryDisabled checks that a client without a RetryConfig sends exactly one request.
func TestRetryDisabled(t *testing.T) {
	requests := 0
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requests++
		// Headers must be set before WriteHeader; later changes are ignored.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusServiceUnavailable)
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client := &HTTPClient{
		Client:      http.DefaultClient,
		RetryConfig: nil,
		SuccessFn:   acceptAll,
	}
	req := &Request{Method: http.MethodGet, URL: server.URL}
	resp, err := client.Do(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	if resp.Status != http.StatusServiceUnavailable {
		t.Errorf("Status = %d; want = %d", resp.Status, http.StatusServiceUnavailable)
	}
	if requests != 1 {
		t.Errorf("Total requests = %d; want = 1", requests)
	}
}

// TestNetworkErrorMaxRetries checks that network errors are retried up to MaxRetries times.
func TestNetworkErrorMaxRetries(t *testing.T) {
	err := errors.New("network error")
	maxRetries := testRetryConfig.MaxRetries
	for i := 0; i < maxRetries; i++ {
		if eligible := testRetryConfig.retryEligible(i, nil, err); !eligible {
			t.Errorf("retryEligible(%d, nil, err) = false; want = true", i)
		}
	}

	if eligible := testRetryConfig.retryEligible(maxRetries, nil, err); eligible {
		t.Errorf("retryEligible(%d, nil, err) = true; want = false", maxRetries)
	}
}

// TestHTTPErrorMaxRetries checks that 503 responses are retried up to MaxRetries times.
func TestHTTPErrorMaxRetries(t *testing.T) {
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
	}
	maxRetries := testRetryConfig.MaxRetries
	for i := 0; i < maxRetries; i++ {
		if eligible := testRetryConfig.retryEligible(i, resp, nil); !eligible {
			t.Errorf("retryEligible(%d, 503, nil) = false; want = true", i)
		}
	}

	if eligible := testRetryConfig.retryEligible(maxRetries, resp, nil); eligible {
		t.Errorf("retryEligible(%d, 503, nil) = true; want = false", maxRetries)
	}
}

// TestNoRetryOnRequestBuildError checks that a failing request body is serialized only once.
func TestNoRetryOnRequestBuildError(t *testing.T) {
	client := &HTTPClient{
		Client:      http.DefaultClient,
		RetryConfig: &testRetryConfig,
	}

	entity := &faultyEntity{}
	req := &Request{
		Method: http.MethodGet,
		URL:    "https://firebase.google.com",
		Body:   entity,
	}
	if _, err := client.Do(context.Background(), req); err == nil {
		t.Errorf("Do() = nil; want = error")
	}

	if entity.RequestAttempts != 1 {
		t.Errorf("Request attempts = %d; want = 1", entity.RequestAttempts)
	}
}

// TestNoRetryOnInvalidMethod checks that Do fails for an invalid HTTP method.
func TestNoRetryOnInvalidMethod(t *testing.T) {
	client := &HTTPClient{
		Client:      http.DefaultClient,
		RetryConfig: &testRetryConfig,
	}

	req := &Request{
		Method: "Invalid/Method",
		URL:    "https://firebase.google.com",
	}
	if _, err := client.Do(context.Background(), req); err == nil {
		t.Errorf("Do() = nil; want = error")
	}
}

// TestNoRetryOnHTTPSuccessCodes checks that status codes 200-399 are not retried by default.
func TestNoRetryOnHTTPSuccessCodes(t *testing.T) {
	for i := http.StatusOK; i < http.StatusBadRequest; i++ {
		resp := &http.Response{
			StatusCode: i,
		}
		if eligible := testRetryConfig.retryEligible(0, resp, nil); eligible {
			t.Errorf("retryEligible(0, %d, nil) = true; want = false", resp.StatusCode)
		}
	}
}

// TestRetryOnHTTPErrorCodes checks that status codes 500-511 are retried by default.
func TestRetryOnHTTPErrorCodes(t *testing.T) {
	for i := http.StatusInternalServerError; i <= http.StatusNetworkAuthenticationRequired; i++ {
		resp := &http.Response{
			StatusCode: i,
		}
		if eligible := testRetryConfig.retryEligible(0, resp, nil); !eligible {
			t.Errorf("retryEligible(0, %d, nil) = false; want = true", resp.StatusCode)
		}
	}
}

// TestRetryAfterHeaderInSecondsFormat checks that a Retry-After value in seconds is used.
func TestRetryAfterHeaderInSecondsFormat(t *testing.T) {
	header := make(http.Header)
	header.Add("retry-after", "30")
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
		Header:     header,
	}

	maxRetries := testRetryConfig.MaxRetries
	for i := 0; i < maxRetries; i++ {
		delay, ok := testRetryConfig.retryDelay(i, resp, nil)
		if !ok || delay != time.Duration(30)*time.Second {
			t.Errorf("retryDelay(%d) = (%f, %v); want = (30.0, true)", i, delay.Seconds(), ok)
		}
	}

	delay, ok := testRetryConfig.retryDelay(maxRetries, resp, nil)
	if ok || delay != 0 {
		t.Errorf("retryDelay(%d) = (%f, %v); want = (0.0, false)", maxRetries, delay.Seconds(), ok)
	}
}

// TestRetryAfterHeaderInTimestampFormat checks that a Retry-After HTTP timestamp is used.
func TestRetryAfterHeaderInTimestampFormat(t *testing.T) {
	header := make(http.Header)
	now := time.Now()
	retryAfter := now.Add(time.Duration(60) * time.Second)
	// http.TimeFormat requires the time be in UTC.
	header.Add("retry-after", retryAfter.UTC().Format(http.TimeFormat))
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
		Header:     header,
	}

	// Restore the package-level clock afterwards so the mock does not leak into other tests.
	originalClock := retryTimeClock
	retryTimeClock = &MockClock{now}
	defer func() { retryTimeClock = originalClock }()

	maxRetries := testRetryConfig.MaxRetries
	for i := 0; i < maxRetries; i++ {
		delay, ok := testRetryConfig.retryDelay(i, resp, nil)
		// HTTP timestamp format has seconds precision. So the final value could be off by 1s.
		if !ok || delay < time.Duration(60-1)*time.Second || delay > time.Duration(60+1)*time.Second {
			t.Errorf("retryDelay(%d) = (%f, %v); want = (~60.0, true)", i, delay.Seconds(), ok)
		}
	}

	delay, ok := testRetryConfig.retryDelay(maxRetries, resp, nil)
	if ok || delay != 0 {
		t.Errorf("retryDelay(%d) = (%f, %v); want = (0.0, false)", maxRetries, delay.Seconds(), ok)
	}
}

// TestMaxDelayWithRetryAfterHeader checks that a Retry-After beyond MaxDelay disables retries.
func TestMaxDelayWithRetryAfterHeader(t *testing.T) {
	header := make(http.Header)
	header.Add("retry-after", "30")
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
		Header:     header,
	}

	tenSeconds := time.Duration(10) * time.Second
	rc := &RetryConfig{
		MaxRetries: 4,
		MaxDelay:   &tenSeconds,
	}
	delay, ok := rc.retryDelay(0, resp, nil)
	if ok || delay != 0 {
		t.Errorf("retryDelay() = (%f, %v); want = (0.0, false)", delay.Seconds(), ok)
	}
}

// TestRetryDelayExpBackoff checks the exponential backoff sequence of the test RetryConfig.
func TestRetryDelayExpBackoff(t *testing.T) {
	want := []int{0, 1, 2, 4}
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
	}

	maxRetries := testRetryConfig.MaxRetries
	for i := 0; i < maxRetries; i++ {
		delay, ok := testRetryConfig.retryDelay(i, resp, nil)
		if !ok || delay != time.Duration(want[i])*time.Second {
			t.Errorf("retryDelay(%d) = (%f, %v); want = (%d, true)", i, delay.Seconds(), ok, want[i])
		}
	}

	delay, ok := testRetryConfig.retryDelay(maxRetries, resp, nil)
	if ok || delay != 0 {
		t.Errorf("retryDelay(%d) = (%f, %v); want = (0, false)", maxRetries, delay.Seconds(), ok)
	}
}

// TestMaxDelayWithExpBackoff checks that the backoff delay is capped by MaxDelay.
func TestMaxDelayWithExpBackoff(t *testing.T) {
	want := []int{0, 2, 4, 5, 5}
	fiveSeconds := time.Duration(5) * time.Second
	rc := &RetryConfig{
		MaxRetries:       5,
		MaxDelay:         &fiveSeconds,
		ExpBackoffFactor: 1,
	}
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
	}

	for i := 0; i < 5; i++ {
		delay, ok := rc.retryDelay(i, resp, nil)
		if !ok || delay != time.Duration(want[i])*time.Second {
			t.Errorf("retryDelay(%d) = (%f, %v); want = (%d, true)", i, delay.Seconds(), ok, want[i])
		}
	}
}

// TestRetryDelayDisableExponentialBackoff checks that a zero backoff factor means no delay.
func TestRetryDelayDisableExponentialBackoff(t *testing.T) {
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
	}
	rc := &RetryConfig{
		MaxRetries:       4,
		ExpBackoffFactor: 0,
	}

	for i := 0; i < 4; i++ {
		delay, ok := rc.retryDelay(i, resp, nil)
		if !ok || delay != 0 {
			t.Errorf("retryDelay(%d) = (%f, %v); want = (0, true)", i, delay.Seconds(), ok)
		}
	}
}

// TestLongestRetryDelayHasPrecedence checks that the longer of the backoff delay and the
// Retry-After delay is used.
func TestLongestRetryDelayHasPrecedence(t *testing.T) {
	header := make(http.Header)
	header.Add("retry-after", "3")
	resp := &http.Response{
		StatusCode: http.StatusServiceUnavailable,
		Header:     header,
	}

	want := []int{0, 1, 2, 4}
	for i := 0; i < 4; i++ {
		delay, ok := testRetryConfig.retryDelay(i, resp, nil)
		if !ok {
			t.Errorf("retryDelay(%d) = false; want = true", i)
		}

		if want[i] <= 3 {
			if delay < time.Duration(3-1)*time.Second || delay > time.Duration(3+1)*time.Second {
				t.Errorf("retryDelay(%d) = %f; want = ~3.0", i, delay.Seconds())
			}
		} else {
			if delay != time.Duration(want[i])*time.Second {
				t.Errorf("retryDelay(%d) = %f; want = %d", i, delay.Seconds(), want[i])
			}
		}
	}
}

// TestContextCancellationStopsRetry checks that waitForRetry reports a cancelled context.
func TestContextCancellationStopsRetry(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	result := &attemptResult{}
	if err := result.waitForRetry(ctx); err != nil {
		t.Fatalf("waitForRetry() = %v; want = nil", err)
	}

	cancel()
	if err := result.waitForRetry(ctx); err != context.Canceled {
		t.Errorf("waitForRetry() = %v; want = %v", err, context.Canceled)
	}
}

// TestNewHTTPClient checks the endpoint and default RetryConfig returned by NewHTTPClient.
func TestNewHTTPClient(t *testing.T) {
	wantEndpoint := "https://cloud.google.com"
	opts := []option.ClientOption{
		tokenSourceOpt,
		option.WithEndpoint(wantEndpoint),
	}
	client, endpoint, err := NewHTTPClient(context.Background(), opts...)
	if err != nil {
		t.Fatal(err)
	}

	wantRetry := &RetryConfig{
		MaxRetries:       4,
		ExpBackoffFactor: 0.5,
	}
	gotRetry := client.RetryConfig
	if gotRetry.MaxRetries != wantRetry.MaxRetries ||
		gotRetry.ExpBackoffFactor != wantRetry.ExpBackoffFactor ||
		gotRetry.CheckForRetry == nil {
		t.Errorf("NewHTTPClient().RetryConfig = %v; want = %v", *gotRetry, wantRetry)
	}
	if endpoint != wantEndpoint {
		t.Errorf("NewHTTPClient() = %q; want = %q", endpoint, wantEndpoint)
	}
}

// TestNewHTTPClientRetryOnNetworkErrors checks that the default client retries network errors.
func TestNewHTTPClientRetryOnNetworkErrors(t *testing.T) {
	client, _, err := NewHTTPClient(context.Background(), tokenSourceOpt)
	if err != nil {
		t.Fatal(err)
	}

	ft := &faultyTransport{}
	client.Client.Transport = ft
	client.RetryConfig.ExpBackoffFactor = 0

	req := &Request{Method: http.MethodGet, URL: "http://firebase.google.com"}
	resp, err := client.Do(context.Background(), req)
	if resp != nil || err == nil {
		t.Errorf("Do() = (%v, %v); want = (nil, error)", resp, err)
	}

	wantRequests := 1 + defaultMaxRetries
	if ft.RequestAttempts != wantRequests {
		t.Errorf("Total requests = %d; want = %d", ft.RequestAttempts, wantRequests)
	}
}

// TestNewHTTPClientRetryOnHTTPErrors checks that the default client retries 503 responses.
func TestNewHTTPClientRetryOnHTTPErrors(t *testing.T) {
	var status int
	requests := 0
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requests++
		// Headers must be set before WriteHeader; later changes are ignored.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client, _, err := NewHTTPClient(context.Background(), tokenSourceOpt)
	if err != nil {
		t.Fatal(err)
	}
	client.RetryConfig.ExpBackoffFactor = 0
	client.SuccessFn = acceptAll

	for _, status = range []int{http.StatusServiceUnavailable} {
		requests = 0
		req := &Request{Method: http.MethodGet, URL: server.URL}
		resp, err := client.Do(context.Background(), req)
		if err != nil {
			t.Fatal(err)
		}

		if resp.Status != status {
			t.Errorf("Status = %d; want = %d", resp.Status, status)
		}

		wantRequests := 1 + defaultMaxRetries
		if requests != wantRequests {
			t.Errorf("Total requests = %d; want = %d", requests, wantRequests)
		}
	}
}

// TestNewHttpClientNoRetryOnNotFound checks that the default client does not retry a 404.
func TestNewHttpClientNoRetryOnNotFound(t *testing.T) {
	requests := 0
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requests++
		// Headers must be set before WriteHeader; later changes are ignored.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusNotFound)
		w.Write([]byte("{}"))
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client, _, err := NewHTTPClient(context.Background(), tokenSourceOpt)
	if err != nil {
		t.Fatal(err)
	}
	client.SuccessFn = acceptAll

	req := &Request{Method: http.MethodGet, URL: server.URL}
	resp, err := client.Do(context.Background(), req)
	if err != nil {
		t.Fatal(err)
	}

	if resp.Status != http.StatusNotFound {
		t.Errorf("Status = %d; want = %d", resp.Status, http.StatusNotFound)
	}

	if requests != 1 {
		t.Errorf("Total requests = %d; want = 1", requests)
	}
}

// TestNewHttpClientRetryOnResponseReadError checks that the default client retries when the
// response body cannot be read.
func TestNewHttpClientRetryOnResponseReadError(t *testing.T) {
	requests := 0
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		requests++
		// Lie about the content-length forcing a read error on the client
		w.Header().Set("Content-Length", "1")
	})
	server := httptest.NewServer(handler)
	defer server.Close()

	client, _, err := NewHTTPClient(context.Background(), tokenSourceOpt)
	if err != nil {
		t.Fatal(err)
	}
	client.RetryConfig.ExpBackoffFactor = 0

	req := &Request{Method: http.MethodGet, URL: server.URL}
	resp, err := client.Do(context.Background(), req)
	if resp != nil || err == nil {
		t.Errorf("Do() = (%v, %v); want = (nil, error)", resp, err)
	}

	wantRequests := 1 + defaultMaxRetries
	if requests != wantRequests {
		t.Errorf("Total requests = %d; want = %d", requests, wantRequests)
	}
}

// TestNilLowLevelResponse checks that LowLevelResponse tolerates a manually built Response.
func TestNilLowLevelResponse(t *testing.T) {
	r := &Response{
		resp: nil,
	}
	if ll := r.LowLevelResponse(); ll != nil {
		t.Errorf("LowLevelResponse() = %v; want = nil", ll)
	}
}

// faultyEntity is an HTTPEntity whose Bytes method always fails, counting its invocations.
type faultyEntity struct {
	RequestAttempts int
}

func (e *faultyEntity) Bytes() ([]byte, error) {
	e.RequestAttempts++
	return nil, errors.New("test error")
}

func (e *faultyEntity) Mime() string {
	return "application/json"
}

// faultyTransport is an http.RoundTripper that always fails, counting its invocations. It
// fails with Err when set, and with a generic test error otherwise.
type faultyTransport struct {
	RequestAttempts int
	Err             error
}

func (e *faultyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	e.RequestAttempts++
	if e.Err != nil {
		return nil, e.Err
	}
	return nil, errors.New("test error")
}

// acceptAll is a SuccessFn that treats every response as successful.
func acceptAll(resp *Response) bool {
	return true
}
golang-google-firebase-go-4.18.0/internal/internal.go000066400000000000000000000067251505612111400224730ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package internal contains functionality that is only accessible from within the Admin SDK.
package internal

import (
	"time"

	"golang.org/x/oauth2"
	"google.golang.org/api/option"
)

// FirebaseScopes is the set of OAuth2 scopes used by the Admin SDK.
var FirebaseScopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/datastore",
	"https://www.googleapis.com/auth/devstorage.full_control",
	"https://www.googleapis.com/auth/firebase",
	"https://www.googleapis.com/auth/identitytoolkit",
	"https://www.googleapis.com/auth/userinfo.email",
}

// SystemClock is a clock that returns local time of the system.
var SystemClock = &systemClock{}

// AuthConfig represents the configuration of Firebase Auth service.
type AuthConfig struct {
	// Opts holds the google.golang.org/api client options carried by this configuration.
	Opts []option.ClientOption
	// NOTE(review): the three string fields below are not interpreted in this file;
	// their exact semantics should be confirmed against the auth package.
	ProjectID        string
	ServiceAccountID string
	Version          string
}

// HashConfig represents a hash algorithm configuration used to generate password hashes.
//
// It is an untyped string-keyed map; this file does not define which keys are expected.
type HashConfig map[string]interface{}

// InstanceIDConfig represents the configuration of Firebase Instance ID service.
type InstanceIDConfig struct {
	// Opts holds the google.golang.org/api client options carried by this configuration.
	Opts      []option.ClientOption
	ProjectID string
	Version   string
}

// DatabaseConfig represents the configuration of Firebase Database service.
type DatabaseConfig struct {
	// Opts holds the google.golang.org/api client options carried by this configuration.
	Opts    []option.ClientOption
	URL     string
	Version string
	// AuthOverride is an untyped string-keyed map.
	// NOTE(review): presumably the auth variable override sent to the database — confirm against the db package.
	AuthOverride map[string]interface{}
}

// StorageConfig represents the configuration of Google Cloud Storage service.
type StorageConfig struct {
	// Opts holds the google.golang.org/api client options carried by this configuration.
	Opts   []option.ClientOption
	Bucket string
}

// MessagingConfig represents the configuration of Firebase Cloud Messaging service.
type MessagingConfig struct {
	// Opts holds the google.golang.org/api client options carried by this configuration.
	Opts      []option.ClientOption
	ProjectID string
	Version   string
}

// RemoteConfigClientConfig represents the configuration of the Firebase Remote Config service.
type RemoteConfigClientConfig struct {
	// Opts holds the google.golang.org/api client options carried by this configuration.
	Opts      []option.ClientOption
	ProjectID string
	Version   string
}

// AppCheckConfig represents the configuration of App Check service.
//
// Unlike the other configuration types in this file it carries no client options.
type AppCheckConfig struct {
	ProjectID string
}

// MockTokenSource is a TokenSource implementation that can be used for testing.
type MockTokenSource struct {
	// AccessToken is the value copied into every token returned by Token.
	AccessToken string
}

// Token returns the test token associated with the TokenSource.
//
// It never fails: the error is always nil and the returned token only has its
// AccessToken field populated.
func (ts *MockTokenSource) Token() (*oauth2.Token, error) {
	return &oauth2.Token{AccessToken: ts.AccessToken}, nil
}

// Clock is used to query the current local time.
type Clock interface {
	// Now reports the current time as seen by this clock.
	Now() time.Time
}

// systemClock returns the current system time.
type systemClock struct{}

// Now returns the current system time by calling time.Now().
func (s *systemClock) Now() time.Time {
	return time.Now()
}

// MockClock can be used to mock current time during tests.
type MockClock struct {
	// Timestamp is the fixed instant reported by Now.
	Timestamp time.Time
}

// Now returns the timestamp set in the MockClock.
func (m *MockClock) Now() time.Time { return m.Timestamp } golang-google-firebase-go-4.18.0/internal/json_http_client_test.go000066400000000000000000000115601505612111400252550ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package internal import ( "context" "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" ) const wantURL = "/test" func TestDoAndUnmarshalGet(t *testing.T) { var req *http.Request handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req = r resp := `{ "name": "test" }` w.Write([]byte(resp)) }) server := httptest.NewServer(handler) defer server.Close() client := &HTTPClient{ Client: http.DefaultClient, } get := &Request{ Method: http.MethodGet, URL: fmt.Sprintf("%s%s", server.URL, wantURL), } var data responseBody resp, err := client.DoAndUnmarshal(context.Background(), get, &data) if err != nil { t.Fatal(err) } if resp.Status != http.StatusOK { t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) } if data.Name != "test" { t.Errorf("Data = %v; want = {Name: %q}", data, "test") } if req.Method != http.MethodGet { t.Errorf("Method = %q; want = %q", req.Method, http.MethodGet) } if req.URL.Path != wantURL { t.Errorf("URL = %q; want = %q", req.URL.Path, wantURL) } } func TestDoAndUnmarshalPost(t *testing.T) { var req *http.Request var b []byte handler := http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { req = r b, _ = ioutil.ReadAll(r.Body) resp := `{ "name": "test" }` w.Write([]byte(resp)) }) server := httptest.NewServer(handler) defer server.Close() client := &HTTPClient{ Client: http.DefaultClient, } post := &Request{ Method: http.MethodPost, URL: fmt.Sprintf("%s%s", server.URL, wantURL), Body: NewJSONEntity(map[string]string{"input": "test-input"}), } var data responseBody resp, err := client.DoAndUnmarshal(context.Background(), post, &data) if err != nil { t.Fatal(err) } if resp.Status != http.StatusOK { t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) } if data.Name != "test" { t.Errorf("Data = %v; want = {Name: %q}", data, "test") } if req.Method != http.MethodPost { t.Errorf("Method = %q; want = %q", req.Method, http.MethodGet) } if req.URL.Path != wantURL { t.Errorf("URL = %q; want = %q", req.URL.Path, wantURL) } var parsed struct { Input string `json:"input"` } if err := json.Unmarshal(b, &parsed); err != nil { t.Fatal(err) } if parsed.Input != "test-input" { t.Errorf("Request Body = %v; want = {Input: %q}", parsed, "test-input") } } func TestDoAndUnmarshalNotJSON(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("not json")) }) server := httptest.NewServer(handler) defer server.Close() client := &HTTPClient{ Client: http.DefaultClient, } get := &Request{ Method: http.MethodGet, URL: server.URL, } var data interface{} wantPrefix := "error while parsing response: " resp, err := client.DoAndUnmarshal(context.Background(), get, &data) if resp != nil || err == nil || !strings.HasPrefix(err.Error(), wantPrefix) { t.Errorf("DoAndUnmarshal() = (%v, %v); want = (nil, %q)", resp, err, wantPrefix) } if data != nil { t.Errorf("Data = %v; want = nil", data) } } func TestDoAndUnmarshalNilPointer(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("not json")) }) server := httptest.NewServer(handler) defer server.Close() client 
:= &HTTPClient{ Client: http.DefaultClient, } get := &Request{ Method: http.MethodGet, URL: server.URL, } resp, err := client.DoAndUnmarshal(context.Background(), get, nil) if err != nil { t.Fatalf("DoAndUnmarshal() = %v; want = nil", err) } if resp.Status != http.StatusOK { t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) } } func TestDoAndUnmarshalTransportError(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) server := httptest.NewServer(handler) server.Close() client := &HTTPClient{ Client: http.DefaultClient, } get := &Request{ Method: http.MethodGet, URL: server.URL, } var data interface{} resp, err := client.DoAndUnmarshal(context.Background(), get, &data) if resp != nil || err == nil { t.Errorf("DoAndUnmarshal() = (%v, %v); want = (nil, error)", resp, err) } if data != nil { t.Errorf("Data = %v; want = nil", data) } } type responseBody struct { Name string `json:"name"` } golang-google-firebase-go-4.18.0/messaging/000077500000000000000000000000001505612111400204575ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/messaging/messaging.go000066400000000000000000001112751505612111400227720ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package messaging contains functions for sending messages and managing // device subscriptions with Firebase Cloud Messaging (FCM). 
package messaging

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"

	"firebase.google.com/go/v4/internal"
	"google.golang.org/api/transport"
)

const (
	defaultMessagingEndpoint = "https://fcm.googleapis.com/v1"
	defaultBatchEndpoint     = "https://fcm.googleapis.com/batch"

	firebaseClientHeader   = "X-Firebase-Client"
	apiFormatVersionHeader = "X-GOOG-API-FORMAT-VERSION"
	apiFormatVersion       = "2"

	apnsAuthError       = "APNS_AUTH_ERROR"
	internalError       = "INTERNAL"
	thirdPartyAuthError = "THIRD_PARTY_AUTH_ERROR"
	invalidArgument     = "INVALID_ARGUMENT"
	quotaExceeded       = "QUOTA_EXCEEDED"
	senderIDMismatch    = "SENDER_ID_MISMATCH"
	unregistered        = "UNREGISTERED"
	unavailable         = "UNAVAILABLE"

	// rfc3339Zulu is the time layout used for timestamps: RFC 3339 in UTC with
	// exactly nine fractional digits and a literal "Z" suffix.
	rfc3339Zulu = "2006-01-02T15:04:05.000000000Z"
)

var (
	// topicNamePattern matches a topic name with an optional "/topics/" prefix
	// and an optional "private/" segment.
	topicNamePattern = regexp.MustCompile("^(/topics/)?(private/)?[a-zA-Z0-9-_.~%]+$")
)

// Message to be sent via Firebase Cloud Messaging.
//
// Message contains payload data, recipient information and platform-specific configuration
// options. A Message must specify exactly one of Token, Topic or Condition fields. Apart from
// that a Message may specify any combination of Data, Notification, Android, Webpush and APNS
// fields. See https://firebase.google.com/docs/reference/fcm/rest/v1/projects.messages for more
// details on how the backend FCM servers handle different message parameters.
type Message struct {
	Data         map[string]string `json:"data,omitempty"`
	Notification *Notification     `json:"notification,omitempty"`
	Android      *AndroidConfig    `json:"android,omitempty"`
	Webpush      *WebpushConfig    `json:"webpush,omitempty"`
	APNS         *APNSConfig       `json:"apns,omitempty"`
	FCMOptions   *FCMOptions       `json:"fcm_options,omitempty"`
	Token        string            `json:"token,omitempty"`
	Topic        string            `json:"-"`
	Condition    string            `json:"condition,omitempty"`
}

// MarshalJSON marshals a Message into JSON (for internal use only).
//
// The Topic field is excluded from the default encoding and is written under
// the "topic" key with any leading "/topics/" prefix removed.
func (m *Message) MarshalJSON() ([]byte, error) {
	// messageInternal has the same fields as Message but none of its methods,
	// so encoding it does not re-enter this MarshalJSON.
	type messageInternal Message
	type wireMessage struct {
		BareTopic string `json:"topic,omitempty"`
		*messageInternal
	}

	out := wireMessage{messageInternal: (*messageInternal)(m)}
	out.BareTopic = strings.TrimPrefix(m.Topic, "/topics/")
	return json.Marshal(&out)
}

// UnmarshalJSON unmarshals a JSON string into a Message (for internal use only).
//
// The "topic" key is copied into Topic verbatim; every other key is decoded
// through the regular struct tags of Message.
func (m *Message) UnmarshalJSON(b []byte) error {
	type messageInternal Message
	var wire struct {
		BareTopic string `json:"topic,omitempty"`
		*messageInternal
	}
	// Decode straight into the receiver through the method-less alias.
	wire.messageInternal = (*messageInternal)(m)

	err := json.Unmarshal(b, &wire)
	if err == nil {
		m.Topic = wire.BareTopic
	}
	return err
}

// Notification is the basic notification template to use across all platforms.
type Notification struct {
	Title    string `json:"title,omitempty"`
	Body     string `json:"body,omitempty"`
	ImageURL string `json:"image,omitempty"`
}

// AndroidConfig contains messaging options specific to the Android platform.
type AndroidConfig struct {
	CollapseKey           string               `json:"collapse_key,omitempty"`
	Priority              string               `json:"priority,omitempty"` // one of "normal" or "high"
	TTL                   *time.Duration       `json:"-"`
	RestrictedPackageName string               `json:"restricted_package_name,omitempty"`
	Data                  map[string]string    `json:"data,omitempty"` // if specified, overrides the Data field on Message type
	Notification          *AndroidNotification `json:"notification,omitempty"`
	FCMOptions            *AndroidFCMOptions   `json:"fcm_options,omitempty"`
	DirectBootOK          bool                 `json:"direct_boot_ok,omitempty"`
}

// MarshalJSON marshals an AndroidConfig into JSON (for internal use only).
func (a *AndroidConfig) MarshalJSON() ([]byte, error) { var ttl string if a.TTL != nil { ttl = durationToString(*a.TTL) } type androidInternal AndroidConfig temp := &struct { TTL string `json:"ttl,omitempty"` *androidInternal }{ TTL: ttl, androidInternal: (*androidInternal)(a), } return json.Marshal(temp) } // UnmarshalJSON unmarshals a JSON string into an AndroidConfig (for internal use only). func (a *AndroidConfig) UnmarshalJSON(b []byte) error { type androidInternal AndroidConfig temp := struct { TTL string `json:"ttl,omitempty"` *androidInternal }{ androidInternal: (*androidInternal)(a), } if err := json.Unmarshal(b, &temp); err != nil { return err } if temp.TTL != "" { ttl, err := stringToDuration(temp.TTL) if err != nil { return err } a.TTL = &ttl } return nil } // AndroidNotification is a notification to send to Android devices. type AndroidNotification struct { Title string `json:"title,omitempty"` // if specified, overrides the Title field of the Notification type Body string `json:"body,omitempty"` // if specified, overrides the Body field of the Notification type Icon string `json:"icon,omitempty"` Color string `json:"color,omitempty"` // notification color in #RRGGBB format Sound string `json:"sound,omitempty"` Tag string `json:"tag,omitempty"` ClickAction string `json:"click_action,omitempty"` BodyLocKey string `json:"body_loc_key,omitempty"` BodyLocArgs []string `json:"body_loc_args,omitempty"` TitleLocKey string `json:"title_loc_key,omitempty"` TitleLocArgs []string `json:"title_loc_args,omitempty"` ChannelID string `json:"channel_id,omitempty"` ImageURL string `json:"image,omitempty"` Ticker string `json:"ticker,omitempty"` Sticky bool `json:"sticky,omitempty"` EventTimestamp *time.Time `json:"-"` LocalOnly bool `json:"local_only,omitempty"` Priority AndroidNotificationPriority `json:"-"` VibrateTimingMillis []int64 `json:"-"` DefaultVibrateTimings bool `json:"default_vibrate_timings,omitempty"` DefaultSound bool `json:"default_sound,omitempty"` 
LightSettings *LightSettings `json:"light_settings,omitempty"` DefaultLightSettings bool `json:"default_light_settings,omitempty"` Visibility AndroidNotificationVisibility `json:"-"` NotificationCount *int `json:"notification_count,omitempty"` Proxy AndroidNotificationProxy `json:"-"` } // MarshalJSON marshals an AndroidNotification into JSON (for internal use only). func (a *AndroidNotification) MarshalJSON() ([]byte, error) { var priority string if a.Priority != priorityUnspecified { priorities := map[AndroidNotificationPriority]string{ PriorityMin: "PRIORITY_MIN", PriorityLow: "PRIORITY_LOW", PriorityDefault: "PRIORITY_DEFAULT", PriorityHigh: "PRIORITY_HIGH", PriorityMax: "PRIORITY_MAX", } priority, _ = priorities[a.Priority] } var visibility string if a.Visibility != visibilityUnspecified { visibilities := map[AndroidNotificationVisibility]string{ VisibilityPrivate: "PRIVATE", VisibilityPublic: "PUBLIC", VisibilitySecret: "SECRET", } visibility, _ = visibilities[a.Visibility] } var proxy string if a.Proxy != proxyUnspecified { proxies := map[AndroidNotificationProxy]string{ ProxyAllow: "ALLOW", ProxyDeny: "DENY", ProxyIfPriorityLowered: "IF_PRIORITY_LOWERED", } proxy, _ = proxies[a.Proxy] } var timestamp string if a.EventTimestamp != nil { timestamp = a.EventTimestamp.UTC().Format(rfc3339Zulu) } var vibTimings []string for _, t := range a.VibrateTimingMillis { vibTimings = append(vibTimings, durationToString(time.Duration(t)*time.Millisecond)) } type androidInternal AndroidNotification temp := &struct { EventTimestamp string `json:"event_time,omitempty"` Priority string `json:"notification_priority,omitempty"` Visibility string `json:"visibility,omitempty"` Proxy string `json:"proxy,omitempty"` VibrateTimings []string `json:"vibrate_timings,omitempty"` *androidInternal }{ EventTimestamp: timestamp, Priority: priority, Visibility: visibility, Proxy: proxy, VibrateTimings: vibTimings, androidInternal: (*androidInternal)(a), } return json.Marshal(temp) } // 
UnmarshalJSON unmarshals a JSON string into an AndroidNotification (for internal use only). func (a *AndroidNotification) UnmarshalJSON(b []byte) error { type androidInternal AndroidNotification temp := struct { EventTimestamp string `json:"event_time,omitempty"` Priority string `json:"notification_priority,omitempty"` Visibility string `json:"visibility,omitempty"` Proxy string `json:"proxy,omitempty"` VibrateTimings []string `json:"vibrate_timings,omitempty"` *androidInternal }{ androidInternal: (*androidInternal)(a), } if err := json.Unmarshal(b, &temp); err != nil { return err } if temp.Priority != "" { priorities := map[string]AndroidNotificationPriority{ "PRIORITY_MIN": PriorityMin, "PRIORITY_LOW": PriorityLow, "PRIORITY_DEFAULT": PriorityDefault, "PRIORITY_HIGH": PriorityHigh, "PRIORITY_MAX": PriorityMax, } if prio, ok := priorities[temp.Priority]; ok { a.Priority = prio } else { return fmt.Errorf("unknown priority value: %q", temp.Priority) } } if temp.Visibility != "" { visibilities := map[string]AndroidNotificationVisibility{ "PRIVATE": VisibilityPrivate, "PUBLIC": VisibilityPublic, "SECRET": VisibilitySecret, } if vis, ok := visibilities[temp.Visibility]; ok { a.Visibility = vis } else { return fmt.Errorf("unknown visibility value: %q", temp.Visibility) } } if temp.Proxy != "" { proxies := map[string]AndroidNotificationProxy{ "ALLOW": ProxyAllow, "DENY": ProxyDeny, "IF_PRIORITY_LOWERED": ProxyIfPriorityLowered, } if prox, ok := proxies[temp.Proxy]; ok { a.Proxy = prox } else { return fmt.Errorf("unknown proxy value: %q", temp.Proxy) } } if temp.EventTimestamp != "" { ts, err := time.Parse(rfc3339Zulu, temp.EventTimestamp) if err != nil { return err } a.EventTimestamp = &ts } var vibTimings []int64 for _, t := range temp.VibrateTimings { vibTime, err := stringToDuration(t) if err != nil { return err } millis := int64(vibTime / time.Millisecond) vibTimings = append(vibTimings, millis) } a.VibrateTimingMillis = vibTimings return nil } // 
AndroidNotificationPriority represents the priority levels of a notification. type AndroidNotificationPriority int const ( priorityUnspecified AndroidNotificationPriority = iota // PriorityMin is the lowest notification priority. Notifications with this priority might not // be shown to the user except under special circumstances, such as detailed notification logs. PriorityMin // PriorityLow is a lower notification priority. The UI may choose to show the notifications // smaller, or at a different position in the list, compared with notifications with PriorityDefault. PriorityLow // PriorityDefault is the default notification priority. If the application does not prioritize // its own notifications, use this value for all notifications. PriorityDefault // PriorityHigh is a higher notification priority. Use this for more important // notifications or alerts. The UI may choose to show these notifications larger, or at a // different position in the notification lists, compared with notifications with PriorityDefault. PriorityHigh // PriorityMax is the highest notification priority. Use this for the application's most // important items that require the user's prompt attention or input. PriorityMax ) // AndroidNotificationVisibility represents the different visibility levels of a notification. type AndroidNotificationVisibility int const ( visibilityUnspecified AndroidNotificationVisibility = iota // VisibilityPrivate shows this notification on all lockscreens, but conceal sensitive or // private information on secure lockscreens. VisibilityPrivate // VisibilityPublic shows this notification in its entirety on all lockscreens. VisibilityPublic // VisibilitySecret does not reveal any part of this notification on a secure lockscreen. VisibilitySecret ) // AndroidNotificationProxy to control when a notification may be proxied. type AndroidNotificationProxy int const ( proxyUnspecified AndroidNotificationProxy = iota // ProxyAllow tries to proxy this notification. 
ProxyAllow // ProxyDeny does not proxy this notification. ProxyDeny // ProxyIfPriorityLowered only tries to proxy this notification if its AndroidConfig's Priority was // lowered from high to normal on the device. ProxyIfPriorityLowered ) // LightSettings to control notification LED. type LightSettings struct { Color string LightOnDurationMillis int64 LightOffDurationMillis int64 } // MarshalJSON marshals an LightSettings into JSON (for internal use only). func (l *LightSettings) MarshalJSON() ([]byte, error) { clr, err := newColor(l.Color) if err != nil { return nil, err } temp := struct { Color *color `json:"color"` LightOnDuration string `json:"light_on_duration"` LightOffDuration string `json:"light_off_duration"` }{ Color: clr, LightOnDuration: durationToString(time.Duration(l.LightOnDurationMillis) * time.Millisecond), LightOffDuration: durationToString(time.Duration(l.LightOffDurationMillis) * time.Millisecond), } return json.Marshal(temp) } // UnmarshalJSON unmarshals a JSON string into an LightSettings (for internal use only). 
func (l *LightSettings) UnmarshalJSON(b []byte) error { temp := struct { Color *color `json:"color"` LightOnDuration string `json:"light_on_duration"` LightOffDuration string `json:"light_off_duration"` }{} if err := json.Unmarshal(b, &temp); err != nil { return err } on, err := stringToDuration(temp.LightOnDuration) if err != nil { return err } off, err := stringToDuration(temp.LightOffDuration) if err != nil { return err } l.Color = temp.Color.toString() l.LightOnDurationMillis = int64(on / time.Millisecond) l.LightOffDurationMillis = int64(off / time.Millisecond) return nil } func durationToString(ms time.Duration) string { seconds := int64(ms / time.Second) nanos := int64((ms - time.Duration(seconds)*time.Second) / time.Nanosecond) if nanos > 0 { return fmt.Sprintf("%d.%09ds", seconds, nanos) } return fmt.Sprintf("%ds", seconds) } func stringToDuration(s string) (time.Duration, error) { segments := strings.Split(strings.TrimSuffix(s, "s"), ".") if len(segments) != 1 && len(segments) != 2 { return 0, fmt.Errorf("incorrect number of segments in ttl: %q", s) } seconds, err := strconv.ParseInt(segments[0], 10, 64) if err != nil { return 0, fmt.Errorf("failed to parse %s: %v", s, err) } ttl := time.Duration(seconds) * time.Second if len(segments) == 2 { nanos, err := strconv.ParseInt(strings.TrimLeft(segments[1], "0"), 10, 64) if err != nil { return 0, fmt.Errorf("failed to parse %s: %v", s, err) } ttl += time.Duration(nanos) * time.Nanosecond } return ttl, nil } type color struct { Red float64 `json:"red"` Green float64 `json:"green"` Blue float64 `json:"blue"` Alpha float64 `json:"alpha"` } func newColor(clr string) (*color, error) { red, err := strconv.ParseInt(clr[1:3], 16, 32) if err != nil { return nil, fmt.Errorf("failed to parse %s: %v", clr, err) } green, err := strconv.ParseInt(clr[3:5], 16, 32) if err != nil { return nil, fmt.Errorf("failed to parse %s: %v", clr, err) } blue, err := strconv.ParseInt(clr[5:7], 16, 32) if err != nil { return nil, 
fmt.Errorf("failed to parse %s: %v", clr, err) } alpha := int64(255) if len(clr) == 9 { alpha, err = strconv.ParseInt(clr[7:9], 16, 32) if err != nil { return nil, fmt.Errorf("failed to parse %s: %v", clr, err) } } return &color{ Red: float64(red) / 255.0, Green: float64(green) / 255.0, Blue: float64(blue) / 255.0, Alpha: float64(alpha) / 255.0, }, nil } func (c *color) toString() string { red := int(c.Red * 255.0) green := int(c.Green * 255.0) blue := int(c.Blue * 255.0) alpha := int(c.Alpha * 255.0) if alpha == 255 { return fmt.Sprintf("#%X%X%X", red, green, blue) } return fmt.Sprintf("#%X%X%X%X", red, green, blue, alpha) } // AndroidFCMOptions contains additional options for features provided by the FCM Android SDK. type AndroidFCMOptions struct { AnalyticsLabel string `json:"analytics_label,omitempty"` } // WebpushConfig contains messaging options specific to the WebPush protocol. // // See https://tools.ietf.org/html/rfc8030#section-5 for additional details, and supported // headers. type WebpushConfig struct { Headers map[string]string `json:"headers,omitempty"` Data map[string]string `json:"data,omitempty"` Notification *WebpushNotification `json:"notification,omitempty"` FCMOptions *WebpushFCMOptions `json:"fcm_options,omitempty"` } // WebpushNotificationAction represents an action that can be performed upon receiving a WebPush notification. type WebpushNotificationAction struct { Action string `json:"action,omitempty"` Title string `json:"title,omitempty"` Icon string `json:"icon,omitempty"` } // WebpushNotification is a notification to send via WebPush protocol. // // See https://developer.mozilla.org/en-US/docs/Web/API/notification/Notification for additional // details. 
type WebpushNotification struct { Actions []*WebpushNotificationAction `json:"actions,omitempty"` Title string `json:"title,omitempty"` // if specified, overrides the Title field of the Notification type Body string `json:"body,omitempty"` // if specified, overrides the Body field of the Notification type Icon string `json:"icon,omitempty"` Badge string `json:"badge,omitempty"` Direction string `json:"dir,omitempty"` // one of 'ltr' or 'rtl' Data interface{} `json:"data,omitempty"` Image string `json:"image,omitempty"` Language string `json:"lang,omitempty"` Renotify bool `json:"renotify,omitempty"` RequireInteraction bool `json:"requireInteraction,omitempty"` Silent bool `json:"silent,omitempty"` Tag string `json:"tag,omitempty"` TimestampMillis *int64 `json:"timestamp,omitempty"` Vibrate []int `json:"vibrate,omitempty"` CustomData map[string]interface{} } // standardFields creates a map containing all the fields except the custom data. // // We implement a standardFields function whenever we want to add custom and arbitrary // fields to an object during its serialization. This helper function also comes in // handy during validation of the message (to detect duplicate specifications of // fields), and also during deserialization. 
func (n *WebpushNotification) standardFields() map[string]interface{} { m := make(map[string]interface{}) addNonEmpty := func(key, value string) { if value != "" { m[key] = value } } addTrue := func(key string, value bool) { if value { m[key] = value } } if len(n.Actions) > 0 { m["actions"] = n.Actions } addNonEmpty("title", n.Title) addNonEmpty("body", n.Body) addNonEmpty("icon", n.Icon) addNonEmpty("badge", n.Badge) addNonEmpty("dir", n.Direction) addNonEmpty("image", n.Image) addNonEmpty("lang", n.Language) addTrue("renotify", n.Renotify) addTrue("requireInteraction", n.RequireInteraction) addTrue("silent", n.Silent) addNonEmpty("tag", n.Tag) if n.Data != nil { m["data"] = n.Data } if n.TimestampMillis != nil { m["timestamp"] = *n.TimestampMillis } if len(n.Vibrate) > 0 { m["vibrate"] = n.Vibrate } return m } // MarshalJSON marshals a WebpushNotification into JSON (for internal use only). func (n *WebpushNotification) MarshalJSON() ([]byte, error) { m := n.standardFields() for k, v := range n.CustomData { m[k] = v } return json.Marshal(m) } // UnmarshalJSON unmarshals a JSON string into a WebpushNotification (for internal use only). func (n *WebpushNotification) UnmarshalJSON(b []byte) error { type webpushNotificationInternal WebpushNotification var temp = (*webpushNotificationInternal)(n) if err := json.Unmarshal(b, temp); err != nil { return err } allFields := make(map[string]interface{}) if err := json.Unmarshal(b, &allFields); err != nil { return err } for k := range n.standardFields() { delete(allFields, k) } if len(allFields) > 0 { n.CustomData = allFields } return nil } // WebpushFCMOptions contains additional options for features provided by the FCM web SDK. type WebpushFCMOptions struct { Link string `json:"link,omitempty"` } // APNSConfig contains messaging options specific to the Apple Push Notification Service (APNS). 
// // See https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html // for more details on supported headers and payload keys. type APNSConfig struct { Headers map[string]string `json:"headers,omitempty"` Payload *APNSPayload `json:"payload,omitempty"` FCMOptions *APNSFCMOptions `json:"fcm_options,omitempty"` LiveActivityToken string `json:"live_activity_token,omitempty"` } // APNSPayload is the payload that can be included in an APNS message. // // The payload mainly consists of the aps dictionary. Additionally it may contain arbitrary // key-values pairs as custom data fields. // // See https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/PayloadKeyReference.html // for a full list of supported payload fields. type APNSPayload struct { Aps *Aps `json:"aps,omitempty"` CustomData map[string]interface{} `json:"-"` } // standardFields creates a map containing all the fields except the custom data. func (p *APNSPayload) standardFields() map[string]interface{} { return map[string]interface{}{"aps": p.Aps} } // MarshalJSON marshals an APNSPayload into JSON (for internal use only). func (p *APNSPayload) MarshalJSON() ([]byte, error) { m := p.standardFields() for k, v := range p.CustomData { m[k] = v } return json.Marshal(m) } // UnmarshalJSON unmarshals a JSON string into an APNSPayload (for internal use only). func (p *APNSPayload) UnmarshalJSON(b []byte) error { type apnsPayloadInternal APNSPayload var temp = (*apnsPayloadInternal)(p) if err := json.Unmarshal(b, temp); err != nil { return err } allFields := make(map[string]interface{}) if err := json.Unmarshal(b, &allFields); err != nil { return err } for k := range p.standardFields() { delete(allFields, k) } if len(allFields) > 0 { p.CustomData = allFields } return nil } // Aps represents the aps dictionary that may be included in an APNSPayload. 
// // Alert may be specified as a string (via the AlertString field), or as a struct (via the Alert // field). type Aps struct { AlertString string `json:"-"` Alert *ApsAlert `json:"-"` Badge *int `json:"badge,omitempty"` Sound string `json:"-"` CriticalSound *CriticalSound `json:"-"` ContentAvailable bool `json:"-"` MutableContent bool `json:"-"` Category string `json:"category,omitempty"` ThreadID string `json:"thread-id,omitempty"` CustomData map[string]interface{} `json:"-"` } // standardFields creates a map containing all the fields except the custom data. func (a *Aps) standardFields() map[string]interface{} { m := make(map[string]interface{}) if a.Alert != nil { m["alert"] = a.Alert } else if a.AlertString != "" { m["alert"] = a.AlertString } if a.ContentAvailable { m["content-available"] = 1 } if a.MutableContent { m["mutable-content"] = 1 } if a.Badge != nil { m["badge"] = *a.Badge } if a.CriticalSound != nil { m["sound"] = a.CriticalSound } else if a.Sound != "" { m["sound"] = a.Sound } if a.Category != "" { m["category"] = a.Category } if a.ThreadID != "" { m["thread-id"] = a.ThreadID } return m } // MarshalJSON marshals an Aps into JSON (for internal use only). func (a *Aps) MarshalJSON() ([]byte, error) { m := a.standardFields() for k, v := range a.CustomData { m[k] = v } return json.Marshal(m) } // UnmarshalJSON unmarshals a JSON string into an Aps (for internal use only). 
func (a *Aps) UnmarshalJSON(b []byte) error { type apsInternal Aps temp := struct { AlertObject *json.RawMessage `json:"alert,omitempty"` SoundObject *json.RawMessage `json:"sound,omitempty"` ContentAvailableInt int `json:"content-available,omitempty"` MutableContentInt int `json:"mutable-content,omitempty"` *apsInternal }{ apsInternal: (*apsInternal)(a), } if err := json.Unmarshal(b, &temp); err != nil { return err } a.ContentAvailable = (temp.ContentAvailableInt == 1) a.MutableContent = (temp.MutableContentInt == 1) if temp.AlertObject != nil { if err := json.Unmarshal(*temp.AlertObject, &a.Alert); err != nil { a.Alert = nil if err := json.Unmarshal(*temp.AlertObject, &a.AlertString); err != nil { return fmt.Errorf("failed to unmarshal alert as a struct or a string: %v", err) } } } if temp.SoundObject != nil { if err := json.Unmarshal(*temp.SoundObject, &a.CriticalSound); err != nil { a.CriticalSound = nil if err := json.Unmarshal(*temp.SoundObject, &a.Sound); err != nil { return fmt.Errorf("failed to unmarshal sound as a struct or a string") } } } allFields := make(map[string]interface{}) if err := json.Unmarshal(b, &allFields); err != nil { return err } for k := range a.standardFields() { delete(allFields, k) } if len(allFields) > 0 { a.CustomData = allFields } return nil } // CriticalSound is the sound payload that can be included in an Aps. type CriticalSound struct { Critical bool `json:"-"` Name string `json:"name,omitempty"` Volume float64 `json:"volume,omitempty"` } // MarshalJSON marshals a CriticalSound into JSON (for internal use only). func (cs *CriticalSound) MarshalJSON() ([]byte, error) { type criticalSoundInternal CriticalSound temp := struct { CriticalInt int `json:"critical,omitempty"` *criticalSoundInternal }{ criticalSoundInternal: (*criticalSoundInternal)(cs), } if cs.Critical { temp.CriticalInt = 1 } return json.Marshal(temp) } // UnmarshalJSON unmarshals a JSON string into a CriticalSound (for internal use only). 
func (cs *CriticalSound) UnmarshalJSON(b []byte) error { type criticalSoundInternal CriticalSound temp := struct { CriticalInt int `json:"critical,omitempty"` *criticalSoundInternal }{ criticalSoundInternal: (*criticalSoundInternal)(cs), } if err := json.Unmarshal(b, &temp); err != nil { return err } cs.Critical = (temp.CriticalInt == 1) return nil } // ApsAlert is the alert payload that can be included in an Aps. // // See https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/PayloadKeyReference.html // for supported fields. type ApsAlert struct { Title string `json:"title,omitempty"` // if specified, overrides the Title field of the Notification type SubTitle string `json:"subtitle,omitempty"` Body string `json:"body,omitempty"` // if specified, overrides the Body field of the Notification type LocKey string `json:"loc-key,omitempty"` LocArgs []string `json:"loc-args,omitempty"` TitleLocKey string `json:"title-loc-key,omitempty"` TitleLocArgs []string `json:"title-loc-args,omitempty"` SubTitleLocKey string `json:"subtitle-loc-key,omitempty"` SubTitleLocArgs []string `json:"subtitle-loc-args,omitempty"` ActionLocKey string `json:"action-loc-key,omitempty"` LaunchImage string `json:"launch-image,omitempty"` } // APNSFCMOptions contains additional options for features provided by the FCM Aps SDK. type APNSFCMOptions struct { AnalyticsLabel string `json:"analytics_label,omitempty"` ImageURL string `json:"image,omitempty"` } // FCMOptions contains additional options to use across all platforms. type FCMOptions struct { AnalyticsLabel string `json:"analytics_label,omitempty"` } // ErrorInfo is a topic management error. type ErrorInfo struct { Index int Reason string } // Client is the interface for the Firebase Cloud Messaging (FCM) service. type Client struct { *fcmClient *iidClient } // NewClient creates a new instance of the Firebase Cloud Messaging Client. 
// // This function can only be invoked from within the SDK. Client applications should access the // the messaging service through firebase.App. func NewClient(ctx context.Context, c *internal.MessagingConfig) (*Client, error) { if c.ProjectID == "" { return nil, errors.New("project ID is required to access Firebase Cloud Messaging client") } hc, messagingEndpoint, err := transport.NewHTTPClient(ctx, c.Opts...) if err != nil { return nil, err } batchEndpoint := messagingEndpoint if messagingEndpoint == "" { messagingEndpoint = defaultMessagingEndpoint batchEndpoint = defaultBatchEndpoint } return &Client{ fcmClient: newFCMClient(hc, c, messagingEndpoint, batchEndpoint), iidClient: newIIDClient(hc, c), }, nil } type fcmClient struct { fcmEndpoint string batchEndpoint string project string version string httpClient *internal.HTTPClient } func newFCMClient(hc *http.Client, conf *internal.MessagingConfig, messagingEndpoint string, batchEndpoint string) *fcmClient { client := internal.WithDefaultRetryConfig(hc) client.CreateErrFn = handleFCMError version := fmt.Sprintf("fire-admin-go/%s", conf.Version) client.Opts = []internal.HTTPOption{ internal.WithHeader(apiFormatVersionHeader, apiFormatVersion), internal.WithHeader(firebaseClientHeader, version), internal.WithHeader("x-goog-api-client", internal.GetMetricsHeader(conf.Version)), } return &fcmClient{ fcmEndpoint: messagingEndpoint, batchEndpoint: batchEndpoint, project: conf.ProjectID, version: version, httpClient: client, } } // Send sends a Message to Firebase Cloud Messaging. // // The Message must specify exactly one of Token, Topic and Condition fields. FCM will // customize the message for each target platform based on the arguments specified in the // Message. 
func (c *fcmClient) Send(ctx context.Context, message *Message) (string, error) { payload := &fcmRequest{ Message: message, } return c.makeSendRequest(ctx, payload) } // SendDryRun sends a Message to Firebase Cloud Messaging in the dry run (validation only) mode. // // This function does not actually deliver the message to target devices. Instead, it performs all // the SDK-level and backend validations on the message, and emulates the send operation. func (c *fcmClient) SendDryRun(ctx context.Context, message *Message) (string, error) { payload := &fcmRequest{ ValidateOnly: true, Message: message, } return c.makeSendRequest(ctx, payload) } func (c *fcmClient) makeSendRequest(ctx context.Context, req *fcmRequest) (string, error) { if err := validateMessage(req.Message); err != nil { return "", err } request := &internal.Request{ Method: http.MethodPost, URL: fmt.Sprintf("%s/projects/%s/messages:send", c.fcmEndpoint, c.project), Body: internal.NewJSONEntity(req), } var result fcmResponse _, err := c.httpClient.DoAndUnmarshal(ctx, request, &result) return result.Name, err } // IsInternal checks if the given error was due to an internal server error. func IsInternal(err error) bool { return hasMessagingErrorCode(err, internalError) } // IsInvalidAPNSCredentials checks if the given error was due to invalid APNS certificate or auth // key. // // Deprecated: Use IsThirdPartyAuthError(). func IsInvalidAPNSCredentials(err error) bool { return IsThirdPartyAuthError(err) } // IsThirdPartyAuthError checks if the given error was due to invalid APNS certificate or auth // key. func IsThirdPartyAuthError(err error) bool { return hasMessagingErrorCode(err, thirdPartyAuthError) || hasMessagingErrorCode(err, apnsAuthError) } // IsInvalidArgument checks if the given error was due to an invalid argument in the request. 
func IsInvalidArgument(err error) bool {
	return hasMessagingErrorCode(err, invalidArgument)
}

// IsMessageRateExceeded checks if the given error was due to the client exceeding a quota.
// It simply forwards to IsQuotaExceeded.
//
// Deprecated: Use IsQuotaExceeded().
func IsMessageRateExceeded(err error) bool {
	return IsQuotaExceeded(err)
}

// IsQuotaExceeded checks if the given error was due to the client exceeding a quota.
func IsQuotaExceeded(err error) bool {
	return hasMessagingErrorCode(err, quotaExceeded)
}

// IsMismatchedCredential checks if the given error was due to an invalid credential or permission
// error. It simply forwards to IsSenderIDMismatch.
//
// Deprecated: Use IsSenderIDMismatch().
func IsMismatchedCredential(err error) bool {
	return IsSenderIDMismatch(err)
}

// IsSenderIDMismatch checks if the given error was due to an invalid credential or permission
// error.
func IsSenderIDMismatch(err error) bool {
	return hasMessagingErrorCode(err, senderIDMismatch)
}

// IsRegistrationTokenNotRegistered checks if the given error was due to a registration token that
// became invalid. It simply forwards to IsUnregistered.
//
// Deprecated: Use IsUnregistered().
func IsRegistrationTokenNotRegistered(err error) bool {
	return IsUnregistered(err)
}

// IsUnregistered checks if the given error was due to a registration token that
// became invalid.
func IsUnregistered(err error) bool {
	return hasMessagingErrorCode(err, unregistered)
}

// IsServerUnavailable checks if the given error was due to the backend server being temporarily
// unavailable. It simply forwards to IsUnavailable.
//
// Deprecated: Use IsUnavailable().
func IsServerUnavailable(err error) bool {
	return IsUnavailable(err)
}

// IsUnavailable checks if the given error was due to the backend server being temporarily
// unavailable.
func IsUnavailable(err error) bool {
	return hasMessagingErrorCode(err, unavailable)
}

// IsTooManyTopics checks if the given error was due to the client exceeding the allowed number
// of topics.
//
// Deprecated: Always returns false.
func IsTooManyTopics(err error) bool { return false } // IsUnknown checks if the given error was due to unknown error returned by the backend server. // // Deprecated: Always returns false. func IsUnknown(err error) bool { return false } type fcmRequest struct { ValidateOnly bool `json:"validate_only,omitempty"` Message *Message `json:"message,omitempty"` } type fcmResponse struct { Name string `json:"name"` } type fcmErrorResponse struct { Error struct { Details []struct { Type string `json:"@type"` ErrorCode string `json:"errorCode"` } } `json:"error"` } func handleFCMError(resp *internal.Response) error { base := internal.NewFirebaseErrorOnePlatform(resp) var fe fcmErrorResponse json.Unmarshal(resp.Body, &fe) // ignore any json parse errors at this level for _, d := range fe.Error.Details { if d.Type == "type.googleapis.com/google.firebase.fcm.v1.FcmError" { base.Ext["messagingErrorCode"] = d.ErrorCode break } } return base } func hasMessagingErrorCode(err error, code string) bool { fe, ok := err.(*internal.FirebaseError) if !ok { return false } got, ok := fe.Ext["messagingErrorCode"] return ok && got == code } golang-google-firebase-go-4.18.0/messaging/messaging_batch.go000066400000000000000000000406131505612111400241300ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package messaging

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"mime"
	"mime/multipart"
	"net/http"
	"net/textproto"

	"firebase.google.com/go/v4/internal"
)

const maxMessages = 500
const multipartBoundary = "__END_OF_PART__"

// MulticastMessage represents a message that can be sent to multiple devices via Firebase Cloud
// Messaging (FCM).
//
// It contains payload information as well as the list of device registration tokens to which the
// message should be sent. A single MulticastMessage may contain up to 500 registration tokens.
type MulticastMessage struct {
	Tokens       []string
	Data         map[string]string
	Notification *Notification
	Android      *AndroidConfig
	Webpush      *WebpushConfig
	APNS         *APNSConfig
	FCMOptions   *FCMOptions
}

// toMessages expands the multicast message into one Message per token, in token order.
//
// Every resulting Message shares the same Data, Notification, Android, Webpush, APNS and
// FCMOptions values. An error is returned when Tokens is empty or holds more than maxMessages
// entries.
func (mm *MulticastMessage) toMessages() ([]*Message, error) {
	count := len(mm.Tokens)
	if count == 0 {
		return nil, errors.New("tokens must not be nil or empty")
	}
	if count > maxMessages {
		return nil, fmt.Errorf("tokens must not contain more than %d elements", maxMessages)
	}

	messages := make([]*Message, count)
	for i, token := range mm.Tokens {
		messages[i] = &Message{
			Token:        token,
			Data:         mm.Data,
			Notification: mm.Notification,
			Android:      mm.Android,
			Webpush:      mm.Webpush,
			APNS:         mm.APNS,
			FCMOptions:   mm.FCMOptions,
		}
	}
	return messages, nil
}

// SendResponse represents the status of an individual message that was sent as part of a batch
// request.
type SendResponse struct {
	Success   bool
	MessageID string
	Error     error
}

// BatchResponse represents the response from the SendAll() and SendMulticast() APIs.
type BatchResponse struct {
	SuccessCount int
	FailureCount int
	Responses    []*SendResponse
}

// SendEach sends the messages in the given array via Firebase Cloud Messaging.
//
// The messages array may contain up to 500 messages. Unlike SendAll(), SendEach sends the entire
// array of messages by making a single HTTP call for each message.
// The responses list
// obtained from the return value corresponds to the order of the input messages. An error
// from SendEach or a BatchResponse with all failures indicates a total failure, meaning that
// none of the messages in the list could be sent. Partial failures or no failures are only
// indicated by a BatchResponse return value.
func (c *fcmClient) SendEach(ctx context.Context, messages []*Message) (*BatchResponse, error) {
	return c.sendEachInBatch(ctx, messages, false)
}

// SendEachDryRun sends the messages in the given array via Firebase Cloud Messaging in the
// dry run (validation only) mode.
//
// This function does not actually deliver any messages to target devices. Instead, it performs all
// the SDK-level and backend validations on the messages, and emulates the send operation.
//
// The messages array may contain up to 500 messages. Unlike SendAllDryRun(), SendEachDryRun sends
// the entire array of messages by making a single HTTP call for each message. The responses list
// obtained from the return value corresponds to the order of the input messages. An error
// from SendEachDryRun or a BatchResponse with all failures indicates a total failure, meaning
// that none of the messages in the list could be sent. Partial failures or no failures are only
// indicated by a BatchResponse return value.
func (c *fcmClient) SendEachDryRun(ctx context.Context, messages []*Message) (*BatchResponse, error) {
	return c.sendEachInBatch(ctx, messages, true)
}

// SendEachForMulticast sends the given multicast message to all the FCM registration tokens specified.
//
// The tokens array in MulticastMessage may contain up to 500 tokens. SendEachForMulticast uses the
// SendEach() function to send the given message to all the target recipients. The
// responses list obtained from the return value corresponds to the order of the input tokens.
// An error
// from SendEachForMulticast or a BatchResponse with all failures indicates a total failure, meaning
// that none of the messages in the list could be sent. Partial failures or no failures are only
// indicated by a BatchResponse return value.
func (c *fcmClient) SendEachForMulticast(ctx context.Context, message *MulticastMessage) (*BatchResponse, error) {
	// Expand the multicast message into one Message per token; a nil or invalid message is
	// rejected here.
	messages, err := toMessages(message)
	if err != nil {
		return nil, err
	}

	return c.SendEach(ctx, messages)
}

// SendEachForMulticastDryRun sends the given multicast message to all the specified FCM registration
// tokens in the dry run (validation only) mode.
//
// This function does not actually deliver any messages to target devices. Instead, it performs all
// the SDK-level and backend validations on the messages, and emulates the send operation.
//
// The tokens array in MulticastMessage may contain up to 500 tokens. SendEachForMulticastDryRun uses the
// SendEachDryRun() function to send the given message. The responses list obtained from
// the return value corresponds to the order of the input tokens. An error from SendEachForMulticastDryRun
// or a BatchResponse with all failures indicates a total failure, meaning that none of the messages in the
// list could be sent. Partial failures or no failures are only
// indicated by a BatchResponse return value.
func (c *fcmClient) SendEachForMulticastDryRun(ctx context.Context, message *MulticastMessage) (*BatchResponse, error) { messages, err := toMessages(message) if err != nil { return nil, err } return c.SendEachDryRun(ctx, messages) } func (c *fcmClient) sendEachInBatch(ctx context.Context, messages []*Message, dryRun bool) (*BatchResponse, error) { if len(messages) == 0 { return nil, errors.New("messages must not be nil or empty") } if len(messages) > maxMessages { return nil, fmt.Errorf("messages must not contain more than %d elements", maxMessages) } for idx, m := range messages { if err := validateMessage(m); err != nil { return nil, fmt.Errorf("invalid message at index %d: %v", idx, err) } } const numWorkers = 50 jobs := make(chan job, len(messages)) results := make(chan result, len(messages)) responses := make([]*SendResponse, len(messages)) for w := 0; w < numWorkers; w++ { go worker(ctx, c, dryRun, jobs, results) } for idx, m := range messages { jobs <- job{message: m, index: idx} } close(jobs) for i := 0; i < len(messages); i++ { res := <-results responses[res.index] = res.response } successCount := 0 failureCount := 0 for _, r := range responses { if r.Success { successCount++ } else { failureCount++ } } return &BatchResponse{ Responses: responses, SuccessCount: successCount, FailureCount: failureCount, }, nil } type job struct { message *Message index int } type result struct { response *SendResponse index int } func worker(ctx context.Context, c *fcmClient, dryRun bool, jobs <-chan job, results chan<- result) { for j := range jobs { var respMsg string var err error if dryRun { respMsg, err = c.SendDryRun(ctx, j.message) } else { respMsg, err = c.Send(ctx, j.message) } var sr *SendResponse if err == nil { sr = &SendResponse{ Success: true, MessageID: respMsg, } } else { sr = &SendResponse{ Success: false, Error: err, } } results <- result{response: sr, index: j.index} } } // SendAll sends the messages in the given array via Firebase Cloud Messaging. 
//
// The messages array may contain up to 500 messages. SendAll employs batching to send the entire
// array of messages as a single RPC call. Compared to the Send() function,
// this is a significantly more efficient way to send multiple messages. The responses list
// obtained from the return value corresponds to the order of the input messages. An error from
// SendAll indicates a total failure, meaning that none of the messages in the array could be
// sent. Partial failures are indicated by a BatchResponse return value.
//
// Deprecated: Use SendEach instead.
func (c *fcmClient) SendAll(ctx context.Context, messages []*Message) (*BatchResponse, error) {
	return c.sendBatch(ctx, messages, false)
}

// SendAllDryRun sends the messages in the given array via Firebase Cloud Messaging in the
// dry run (validation only) mode.
//
// This function does not actually deliver any messages to target devices. Instead, it performs all
// the SDK-level and backend validations on the messages, and emulates the send operation.
//
// The messages array may contain up to 500 messages. SendAllDryRun employs batching to send the
// entire array of messages as a single RPC call. Compared to the SendDryRun() function, this
// is a significantly more efficient way to validate sending multiple messages. The responses list
// obtained from the return value corresponds to the order of the input messages. An error from
// SendAllDryRun indicates a total failure, meaning that none of the messages in the array could
// be sent for validation. Partial failures are indicated by a BatchResponse return value.
//
// Deprecated: Use SendEachDryRun instead.
func (c *fcmClient) SendAllDryRun(ctx context.Context, messages []*Message) (*BatchResponse, error) {
	return c.sendBatch(ctx, messages, true)
}

// SendMulticast sends the given multicast message to all the FCM registration tokens specified.
//
// The tokens array in MulticastMessage may contain up to 500 tokens.
// SendMulticast uses the
// SendAll() function to send the given message to all the target recipients. The
// responses list obtained from the return value corresponds to the order of the input tokens. An
// error from SendMulticast indicates a total failure, meaning that the message could not be sent
// to any of the recipients. Partial failures are indicated by a BatchResponse return value.
//
// Deprecated: Use SendEachForMulticast instead.
func (c *fcmClient) SendMulticast(ctx context.Context, message *MulticastMessage) (*BatchResponse, error) {
	// Expand the multicast message into one Message per token; a nil or invalid message is
	// rejected here.
	messages, err := toMessages(message)
	if err != nil {
		return nil, err
	}

	return c.SendAll(ctx, messages)
}

// SendMulticastDryRun sends the given multicast message to all the specified FCM registration
// tokens in the dry run (validation only) mode.
//
// This function does not actually deliver any messages to target devices. Instead, it performs all
// the SDK-level and backend validations on the messages, and emulates the send operation.
//
// The tokens array in MulticastMessage may contain up to 500 tokens. SendMulticastDryRun uses the
// SendAllDryRun() function to send the given message. The responses list obtained from
// the return value corresponds to the order of the input tokens. An error from SendMulticastDryRun
// indicates a total failure, meaning that none of the messages were sent to FCM for validation.
// Partial failures are indicated by a BatchResponse return value.
//
// Deprecated: Use SendEachForMulticastDryRun instead.
func (c *fcmClient) SendMulticastDryRun(ctx context.Context, message *MulticastMessage) (*BatchResponse, error) { messages, err := toMessages(message) if err != nil { return nil, err } return c.SendAllDryRun(ctx, messages) } func toMessages(message *MulticastMessage) ([]*Message, error) { if message == nil { return nil, errors.New("message must not be nil") } return message.toMessages() } func (c *fcmClient) sendBatch( ctx context.Context, messages []*Message, dryRun bool) (*BatchResponse, error) { if len(messages) == 0 { return nil, errors.New("messages must not be nil or empty") } if len(messages) > maxMessages { return nil, fmt.Errorf("messages must not contain more than %d elements", maxMessages) } request, err := c.newBatchRequest(messages, dryRun) if err != nil { return nil, err } resp, err := c.httpClient.Do(ctx, request) if err != nil { return nil, err } if resp.Status != http.StatusOK { return nil, handleFCMError(resp) } return newBatchResponse(resp) } // part represents a HTTP request that can be sent embedded in a multipart batch request. // // See https://cloud.google.com/compute/docs/api/how-tos/batch for details on how GCP APIs support multipart batch // requests. type part struct { method string url string headers map[string]string body interface{} } // multipartEntity represents an HTTP entity that consists of multiple HTTP requests (parts). 
type multipartEntity struct { parts []*part } func (c *fcmClient) newBatchRequest(messages []*Message, dryRun bool) (*internal.Request, error) { url := fmt.Sprintf("%s/projects/%s/messages:send", c.fcmEndpoint, c.project) headers := map[string]string{ apiFormatVersionHeader: apiFormatVersion, firebaseClientHeader: c.version, } var parts []*part for idx, m := range messages { if err := validateMessage(m); err != nil { return nil, fmt.Errorf("invalid message at index %d: %v", idx, err) } p := &part{ method: http.MethodPost, url: url, body: &fcmRequest{ Message: m, ValidateOnly: dryRun, }, headers: headers, } parts = append(parts, p) } return &internal.Request{ Method: http.MethodPost, URL: c.batchEndpoint, Body: &multipartEntity{parts: parts}, Opts: []internal.HTTPOption{ internal.WithHeader(firebaseClientHeader, c.version), }, }, nil } func newBatchResponse(resp *internal.Response) (*BatchResponse, error) { _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) if err != nil { return nil, fmt.Errorf("error parsing content-type header: %v", err) } mr := multipart.NewReader(bytes.NewBuffer(resp.Body), params["boundary"]) var responses []*SendResponse successCount := 0 for { part, err := mr.NextPart() if err == io.EOF { break } else if err != nil { return nil, err } sr, err := newSendResponse(part) if err != nil { return nil, err } responses = append(responses, sr) if sr.Success { successCount++ } } return &BatchResponse{ Responses: responses, SuccessCount: successCount, FailureCount: len(responses) - successCount, }, nil } func newSendResponse(part *multipart.Part) (*SendResponse, error) { hr, err := http.ReadResponse(bufio.NewReader(part), nil) if err != nil { return nil, fmt.Errorf("error parsing multipart body: %v", err) } b, err := ioutil.ReadAll(hr.Body) if err != nil { return nil, err } if hr.StatusCode != http.StatusOK { resp := &internal.Response{ Status: hr.StatusCode, Header: hr.Header, Body: b, } return &SendResponse{ Success: false, Error: 
handleFCMError(resp), }, nil } var result fcmResponse if err := json.Unmarshal(b, &result); err != nil { return nil, err } return &SendResponse{ Success: true, MessageID: result.Name, }, nil } func (e *multipartEntity) Mime() string { return fmt.Sprintf("multipart/mixed; boundary=%s", multipartBoundary) } func (e *multipartEntity) Bytes() ([]byte, error) { var buffer bytes.Buffer writer := multipart.NewWriter(&buffer) writer.SetBoundary(multipartBoundary) for idx, part := range e.parts { if err := part.writeTo(writer, idx); err != nil { return nil, err } } writer.Close() return buffer.Bytes(), nil } func (p *part) writeTo(writer *multipart.Writer, idx int) error { b, err := p.bytes() if err != nil { return err } header := make(textproto.MIMEHeader) header.Add("Content-Length", fmt.Sprintf("%d", len(b))) header.Add("Content-Type", "application/http") header.Add("Content-Id", fmt.Sprintf("%d", idx+1)) header.Add("Content-Transfer-Encoding", "binary") part, err := writer.CreatePart(header) if err != nil { return err } _, err = part.Write(b) return err } func (p *part) bytes() ([]byte, error) { b, err := json.Marshal(p.body) if err != nil { return nil, err } req, err := http.NewRequest(p.method, p.url, bytes.NewBuffer(b)) if err != nil { return nil, err } for key, value := range p.headers { req.Header.Set(key, value) } req.Header.Set("Content-Type", "application/json; charset=UTF-8") req.Header.Set("User-Agent", "") var buffer bytes.Buffer if err := req.Write(&buffer); err != nil { return nil, err } return buffer.Bytes(), nil } golang-google-firebase-go-4.18.0/messaging/messaging_batch_test.go000066400000000000000000001220601505612111400251640ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messaging import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "io/ioutil" "mime/multipart" "net/http" "net/http/httptest" "net/textproto" "strings" "sync" "testing" "google.golang.org/api/option" ) var testMessages = []*Message{{Topic: "topic1"}, {Topic: "topic2"}} var testMulticastMessage = &MulticastMessage{ Tokens: []string{"token1", "token2"}, } var testSuccessResponse = []fcmResponse{ {Name: "projects/test-project/messages/1"}, {Name: "projects/test-project/messages/2"}, } const wantMime = "multipart/mixed; boundary=__END_OF_PART__" const wantSendURL = "/v1/projects/test-project/messages:send" func TestMultipartEntitySingle(t *testing.T) { entity := &multipartEntity{ parts: []*part{{ method: "POST", url: "http://example.com", body: map[string]interface{}{"key": "value"}, }}, } mime := entity.Mime() if mime != wantMime { t.Errorf("Mime() = %q; want = %q", mime, wantMime) } b, err := entity.Bytes() if err != nil { t.Fatal(err) } want := "--__END_OF_PART__\r\n" + "Content-Id: 1\r\n" + "Content-Length: 120\r\n" + "Content-Transfer-Encoding: binary\r\n" + "Content-Type: application/http\r\n" + "\r\n" + "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "Content-Length: 15\r\n" + "Content-Type: application/json; charset=UTF-8\r\n" + "\r\n" + "{\"key\":\"value\"}\r\n" + "--__END_OF_PART__--\r\n" if string(b) != want { t.Errorf("Bytes() = %q; want = %q", string(b), want) } } func TestSendEachWorkerPoolScenarios(t *testing.T) { scenarios := []struct { name string numMessages int allSuccessful bool testNameSuffix string // To make test names 
more descriptive if needed }{ {numMessages: 5, allSuccessful: true, testNameSuffix: " (5msg < 50workers)"}, {numMessages: 50, allSuccessful: true, testNameSuffix: " (50msg == 50workers)"}, {numMessages: 75, allSuccessful: true, testNameSuffix: " (75msg > 50workers)"}, {numMessages: 75, allSuccessful: false, testNameSuffix: " (75msg > 50workers, with Failures)"}, } for _, s := range scenarios { scenarioName := fmt.Sprintf("NumMessages_%d_AllSuccess_%v%s", s.numMessages, s.allSuccessful, s.testNameSuffix) t.Run(scenarioName, func(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } messages := make([]*Message, s.numMessages) expectedSuccessCount := s.numMessages expectedFailureCount := 0 serverHitCount := 0 mu := &sync.Mutex{} // To protect serverHitCount ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() serverHitCount++ mu.Unlock() var reqBody fcmRequest if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { w.WriteHeader(http.StatusBadRequest) return } var originalIndex int if !s.allSuccessful { // Only parse index if we might fail based on it topicParts := strings.Split(reqBody.Message.Topic, "topic") if len(topicParts) == 2 { fmt.Sscanf(topicParts[1], "%d", &originalIndex) } else { t.Logf("Unexpected topic format: %s", reqBody.Message.Topic) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]string{ "name": fmt.Sprintf("projects/test-project/messages/%s-unexpected", reqBody.Message.Topic), }) return } } if !s.allSuccessful && (originalIndex+1)%3 == 0 { w.WriteHeader(http.StatusInternalServerError) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ "error": map[string]interface{}{ "message": fmt.Sprintf("Simulated server error for original index %d", originalIndex), "status": "INTERNAL", }, }) } else { w.Header().Set("Content-Type", "application/json") 
json.NewEncoder(w).Encode(map[string]string{ "name": fmt.Sprintf("projects/test-project/messages/%s-idx%d", reqBody.Message.Topic, originalIndex), }) } })) defer ts.Close() client.fcmEndpoint = ts.URL for i := 0; i < s.numMessages; i++ { messages[i] = &Message{Topic: fmt.Sprintf("topic%d", i)} } if !s.allSuccessful { expectedSuccessCount = 0 expectedFailureCount = 0 for i := 0; i < s.numMessages; i++ { if (i+1)%3 == 0 { expectedFailureCount++ } else { expectedSuccessCount++ } } } br, err := client.SendEach(ctx, messages) if err != nil { t.Fatalf("SendEach() unexpected error: %v", err) } if br.SuccessCount != expectedSuccessCount { t.Errorf("SuccessCount = %d; want = %d", br.SuccessCount, expectedSuccessCount) } if br.FailureCount != expectedFailureCount { t.Errorf("FailureCount = %d; want = %d", br.FailureCount, expectedFailureCount) } if len(br.Responses) != s.numMessages { t.Errorf("len(Responses) = %d; want = %d", len(br.Responses), s.numMessages) } mu.Lock() // Protect serverHitCount read if serverHitCount != s.numMessages { t.Errorf("Server hit count = %d; want = %d", serverHitCount, s.numMessages) } mu.Unlock() for i, resp := range br.Responses { isExpectedToSucceed := s.allSuccessful || (i+1)%3 != 0 if resp.Success != isExpectedToSucceed { t.Errorf("Responses[%d].Success = %v; want = %v", i, resp.Success, isExpectedToSucceed) } if isExpectedToSucceed && resp.MessageID == "" { t.Errorf("Responses[%d].MessageID is empty for a successful message", i) } if !isExpectedToSucceed && resp.Error == nil { t.Errorf("Responses[%d].Error is nil for a failed message", i) } } }) } } func TestSendEachResponseOrderWithConcurrency(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } numMessages := 75 // Ensure this is > new worker count of 50 messages := make([]*Message, numMessages) for i := 0; i < numMessages; i++ { messages[i] = &Message{Token: fmt.Sprintf("token%d", i)} // Using Token for unique 
identification } serverHitCount := 0 messageIDLog := make(map[string]int) var mu sync.Mutex ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() serverHitCount++ hitOrder := serverHitCount mu.Unlock() var reqBody fcmRequest if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { w.WriteHeader(http.StatusBadRequest) return } messageIdentifier := reqBody.Message.Token mu.Lock() messageIDLog[messageIdentifier] = hitOrder mu.Unlock() w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{ "name": fmt.Sprintf("projects/test-project/messages/msg_for_%s", messageIdentifier), }) })) defer ts.Close() client.fcmEndpoint = ts.URL br, err := client.SendEach(ctx, messages) if err != nil { t.Fatalf("SendEach() unexpected error: %v", err) } if br.SuccessCount != numMessages { t.Errorf("SuccessCount = %d; want = %d", br.SuccessCount, numMessages) } if len(br.Responses) != numMessages { t.Errorf("len(Responses) = %d; want = %d", len(br.Responses), numMessages) } if serverHitCount != numMessages { t.Errorf("Server hit count = %d; want = %d", serverHitCount, numMessages) } for i, resp := range br.Responses { if !resp.Success { t.Errorf("Responses[%d] was not successful: %v", i, resp.Error) continue } expectedMessageIDPart := fmt.Sprintf("msg_for_token%d", i) if !strings.Contains(resp.MessageID, expectedMessageIDPart) { t.Errorf("Responses[%d].MessageID = %q; want to contain %q", i, resp.MessageID, expectedMessageIDPart) } } } func TestSendEachEarlyValidationSkipsSend(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } messagesWithInvalid := []*Message{{Topic: "topic1"}, nil, {Topic: "topic2"}} serverHitCount := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { serverHitCount++ w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{ 
"name":"projects/test-project/messages/1" }`)) })) defer ts.Close() client.fcmEndpoint = ts.URL br, err := client.SendEach(ctx, messagesWithInvalid) if err == nil { t.Errorf("SendEach() expected error for invalid message, got nil") } if br != nil { t.Errorf("SendEach() expected nil BatchResponse for invalid message, got %v", br) } if serverHitCount != 0 { t.Errorf("Server hit count = %d; want = 0 due to early validation failure", serverHitCount) } messagesWithInvalidFirst := []*Message{ {Topic: "invalid", Condition: "invalid"}, // Invalid: both Topic and Condition {Topic: "topic1"}, } serverHitCount = 0 br, err = client.SendEach(ctx, messagesWithInvalidFirst) if err == nil { t.Errorf("SendEach() expected error for invalid first message, got nil") } if br != nil { t.Errorf("SendEach() expected nil BatchResponse for invalid first message, got %v", br) } if serverHitCount != 0 { t.Errorf("Server hit count = %d; want = 0 for invalid first message", serverHitCount) } messagesWithInvalidLast := []*Message{ {Topic: "topic1"}, // Valid first message {Topic: "topic_last", Token: "token_last"}, // Invalid: cannot have both Topic and Token } serverHitCount = 0 br, err = client.SendEach(ctx, messagesWithInvalidLast) if err == nil { t.Errorf("SendEach() expected error for invalid last message, got nil") } if br != nil { t.Errorf("SendEach() expected nil BatchResponse for invalid last message, got %v", br) } if serverHitCount != 0 { t.Errorf("Server hit count = %d; want = 0 for invalid last message", serverHitCount) } } func TestMultipartEntity(t *testing.T) { entity := &multipartEntity{ parts: []*part{ { method: "POST", url: "http://example1.com", body: map[string]interface{}{"key1": "value"}, }, { method: "POST", url: "http://example2.com", body: map[string]interface{}{"key2": "value"}, headers: map[string]string{"Custom-Header": "custom-value"}, }, }, } mime := entity.Mime() if mime != wantMime { t.Errorf("Mime() = %q; want = %q", mime, wantMime) } b, err := entity.Bytes() if 
err != nil { t.Fatal(err) } want := "--__END_OF_PART__\r\n" + "Content-Id: 1\r\n" + "Content-Length: 122\r\n" + "Content-Transfer-Encoding: binary\r\n" + "Content-Type: application/http\r\n" + "\r\n" + "POST / HTTP/1.1\r\n" + "Host: example1.com\r\n" + "Content-Length: 16\r\n" + "Content-Type: application/json; charset=UTF-8\r\n" + "\r\n" + "{\"key1\":\"value\"}\r\n" + "--__END_OF_PART__\r\n" + "Content-Id: 2\r\n" + "Content-Length: 151\r\n" + "Content-Transfer-Encoding: binary\r\n" + "Content-Type: application/http\r\n" + "\r\n" + "POST / HTTP/1.1\r\n" + "Host: example2.com\r\n" + "Content-Length: 16\r\n" + "Content-Type: application/json; charset=UTF-8\r\n" + "Custom-Header: custom-value\r\n" + "\r\n" + "{\"key2\":\"value\"}\r\n" + "--__END_OF_PART__--\r\n" if string(b) != want { t.Errorf("multipartPayload() = %q; want = %q", string(b), want) } } func TestMultipartEntityError(t *testing.T) { entity := &multipartEntity{ parts: []*part{{ method: "POST", url: "http://example.com", body: func() {}, }}, } b, err := entity.Bytes() if b != nil || err == nil { t.Errorf("Bytes() = (%v, %v); want = (nil, error)", b, nil) } } func TestSendEachEmptyArray(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "messages must not be nil or empty" br, err := client.SendEach(ctx, nil) if err == nil || err.Error() != want { t.Errorf("SendEach(nil) = (%v, %v); want = (nil, %q)", br, err, want) } br, err = client.SendEach(ctx, []*Message{}) if err == nil || err.Error() != want { t.Errorf("SendEach(nil) = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEachTooManyMessages(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } var messages []*Message for i := 0; i < 501; i++ { messages = append(messages, &Message{Topic: "test-topic"}) } want := "messages must not contain more than 500 elements" br, err := 
client.SendEach(ctx, messages) if err == nil || err.Error() != want { t.Errorf("SendEach() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEachInvalidMessage(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "invalid message at index 0: message must not be nil" br, err := client.SendEach(ctx, []*Message{nil}) if err == nil || err.Error() != want { t.Errorf("SendEach() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEach(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ := ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") for idx, testMessage := range testMessages { if strings.Contains(string(req), testMessage.Topic) { w.Write([]byte("{ \"name\":\"" + testSuccessResponse[idx].Name + "\" }")) } } })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL br, err := client.SendEach(ctx, testMessages) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponseForSendEach(br, false); err != nil { t.Errorf("SendEach() = %v", err) } } func TestSendEachDryRun(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ := ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") for idx, testMessage := range testMessages { if strings.Contains(string(req), testMessage.Topic) { w.Write([]byte("{ \"name\":\"" + testSuccessResponse[idx].Name + "\" }")) } } })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL br, err := client.SendEachDryRun(ctx, testMessages) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponseForSendEach(br, true); err != nil { t.Errorf("SendEach() = %v", err) } } func 
TestSendEachPartialFailure(t *testing.T) { success := []fcmResponse{ {Name: "projects/test-project/messages/1"}, } var failures []string ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } for idx, tc := range httpErrors { failures = []string{tc.resp} serverHitCount := 0 var mu sync.Mutex ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() serverHitCount++ mu.Unlock() reqBody, _ := ioutil.ReadAll(r.Body) var msgIn fcmRequest json.Unmarshal(reqBody, &msgIn) if msgIn.Message.Topic == testMessages[0].Topic { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{ "name":"` + success[0].Name + `" }`)) } else if msgIn.Message.Topic == testMessages[1].Topic { w.WriteHeader(http.StatusInternalServerError) w.Header().Set("Content-Type", "application/json") w.Write([]byte(failures[0])) } else { w.WriteHeader(http.StatusBadRequest) w.Write([]byte(`{"error":"unknown topic"}`)) } })) defer ts.Close() client.fcmEndpoint = ts.URL br, err := client.SendEach(ctx, testMessages) if err != nil { t.Fatalf("[%d] SendEach() unexpected error: %v", idx, err) } mu.Lock() if serverHitCount != len(testMessages) { t.Errorf("[%d] Server hit count = %d; want = %d", idx, serverHitCount, len(testMessages)) } mu.Unlock() if err := checkPartialErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendEach() = %v", idx, err) } } } func TestSendEachTotalFailure(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmClient.httpClient.RetryConfig = nil for idx, tc := range httpErrors { serverHitCount := 0 var mu sync.Mutex ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mu.Lock() serverHitCount++ mu.Unlock() w.WriteHeader(http.StatusInternalServerError) w.Header().Set("Content-Type", "application/json") w.Write([]byte(tc.resp)) })) defer ts.Close() 
client.fcmEndpoint = ts.URL br, err := client.SendEach(ctx, testMessages) if err != nil { t.Fatalf("[%d] SendEach() unexpected error: %v", idx, err) } mu.Lock() if serverHitCount != len(testMessages) { t.Errorf("[%d] Server hit count = %d; want = %d", idx, serverHitCount, len(testMessages)) } mu.Unlock() if err := checkTotalErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendEach() = %v", idx, err) } } } func TestSendEachForMulticastNil(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "message must not be nil" br, err := client.SendEachForMulticast(ctx, nil) if err == nil || err.Error() != want { t.Errorf("SendEachForMulticast(nil) = (%v, %v); want = (nil, %q)", br, err, want) } br, err = client.SendEachForMulticastDryRun(ctx, nil) if err == nil || err.Error() != want { t.Errorf("SendEachForMulticast(nil) = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEachForMulticastEmptyArray(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "tokens must not be nil or empty" mm := &MulticastMessage{} br, err := client.SendEachForMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendEachForMulticast(Tokens: nil) = (%v, %v); want = (nil, %q)", br, err, want) } var tokens []string mm = &MulticastMessage{ Tokens: tokens, } br, err = client.SendEachForMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendEachForMulticast(Tokens: []) = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEachForMulticastTooManyTokens(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } var tokens []string for i := 0; i < 501; i++ { tokens = append(tokens, fmt.Sprintf("token%d", i)) } want := "tokens must not contain more than 500 elements" mm := &MulticastMessage{Tokens: tokens} br, err := 
client.SendEachForMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendEachForMulticast() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEachForMulticastInvalidMessage(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "invalid message at index 0: priority must be 'normal' or 'high'" mm := &MulticastMessage{ Tokens: []string{"token1"}, Android: &AndroidConfig{Priority: "invalid"}, } br, err := client.SendEachForMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendEachForMulticast() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendEachForMulticast(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ := ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") for idx, token := range testMulticastMessage.Tokens { if strings.Contains(string(req), token) { w.Write([]byte("{ \"name\":\"" + testSuccessResponse[idx].Name + "\" }")) } } })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL br, err := client.SendEachForMulticast(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponseForSendEach(br, false); err != nil { t.Errorf("SendEachForMulticast() = %v", err) } } func TestSendEachForMulticastWithCustomEndpoint(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ := ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") for idx, token := range testMulticastMessage.Tokens { if strings.Contains(string(req), token) { w.Write([]byte("{ \"name\":\"" + testSuccessResponse[idx].Name + "\" }")) } } })) defer ts.Close() ctx := context.Background() conf := *testMessagingConfig optEndpoint := option.WithEndpoint(ts.URL) conf.Opts = 
append(conf.Opts, optEndpoint) client, err := NewClient(ctx, &conf) if err != nil { t.Fatal(err) } if ts.URL != client.fcmEndpoint { t.Errorf("client.fcmEndpoint = %q; want = %q", client.fcmEndpoint, ts.URL) } br, err := client.SendEachForMulticast(ctx, testMulticastMessage) if err := checkSuccessfulBatchResponseForSendEach(br, false); err != nil { t.Errorf("SendEachForMulticast() = %v", err) } } func TestSendEachForMulticastDryRun(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ := ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") for idx, token := range testMulticastMessage.Tokens { if strings.Contains(string(req), token) { w.Write([]byte("{ \"name\":\"" + testSuccessResponse[idx].Name + "\" }")) } } })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL br, err := client.SendEachForMulticastDryRun(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponseForSendEach(br, true); err != nil { t.Errorf("SendEachForMulticastDryRun() = %v", err) } } func TestSendEachForMulticastPartialFailure(t *testing.T) { success := []fcmResponse{ {Name: "projects/test-project/messages/1"}, } var failures []string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ := ioutil.ReadAll(r.Body) for idx, token := range testMulticastMessage.Tokens { if strings.Contains(string(req), token) { if idx%2 == 0 { w.Header().Set("Content-Type", wantMime) w.Write([]byte("{ \"name\":\"" + success[0].Name + "\" }")) } else { w.WriteHeader(http.StatusInternalServerError) w.Header().Set("Content-Type", wantMime) w.Write([]byte(failures[0])) } } } })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL for idx, tc := range 
httpErrors { failures = []string{tc.resp} br, err := client.SendEachForMulticast(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkPartialErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendEachForMulticast() = %v", idx, err) } } } func TestSendAllEmptyArray(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "messages must not be nil or empty" br, err := client.SendAll(ctx, nil) if err == nil || err.Error() != want { t.Errorf("SendAll(nil) = (%v, %v); want = (nil, %q)", br, err, want) } br, err = client.SendAll(ctx, []*Message{}) if err == nil || err.Error() != want { t.Errorf("SendAll(nil) = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendAllTooManyMessages(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } var messages []*Message for i := 0; i < 501; i++ { messages = append(messages, &Message{Topic: "test-topic"}) } want := "messages must not contain more than 500 elements" br, err := client.SendAll(ctx, messages) if err == nil || err.Error() != want { t.Errorf("SendAll() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendAllInvalidMessage(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "invalid message at index 0: message must not be nil" br, err := client.SendAll(ctx, []*Message{nil}) if err == nil || err.Error() != want { t.Errorf("SendAll() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendAll(t *testing.T) { resp, err := createMultipartResponse(testSuccessResponse, nil) if err != nil { t.Fatal(err) } var req []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() client, err := 
NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL br, err := client.SendAll(ctx, testMessages) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponse(br, req, false); err != nil { t.Errorf("SendAll() = %v", err) } } func TestSendAllDryRun(t *testing.T) { resp, err := createMultipartResponse(testSuccessResponse, nil) if err != nil { t.Fatal(err) } var req []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL br, err := client.SendAllDryRun(ctx, testMessages) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponse(br, req, true); err != nil { t.Errorf("SendAll() = %v", err) } } func TestSendAllPartialFailure(t *testing.T) { success := []fcmResponse{ {Name: "projects/test-project/messages/1"}, } var req, resp []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL for idx, tc := range httpErrors { failures := []string{tc.resp} resp, err = createMultipartResponse(success, failures) if err != nil { t.Fatal(err) } br, err := client.SendAll(ctx, testMessages) if err != nil { t.Fatal(err) } if err := checkPartialErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendAll() = %v", idx, err) } if err := checkMultipartRequest(req, false); err != nil { t.Errorf("[%d] MultipartRequest: %v", idx, err) } } } func TestSendAllTotalFailure(t *testing.T) { var resp string ts := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) w.Header().Set("Content-Type", "application/json") w.Write([]byte(resp)) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL client.fcmClient.httpClient.RetryConfig = nil for _, tc := range httpErrors { resp = tc.resp br, err := client.SendAll(ctx, []*Message{{Topic: "topic"}}) if err == nil || err.Error() != tc.want || !tc.check(err) { t.Errorf("SendAll() = (%v, %v); want = (nil, %q)", br, err, tc.want) } } } func TestSendAllNonMultipartResponse(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte("{}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL if _, err = client.SendAll(ctx, testMessages); err == nil { t.Fatal("SendAll() = nil; want = error") } } func TestSendAllMalformedContentType(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "invalid content-type") w.Write([]byte("{}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL if _, err = client.SendAll(ctx, testMessages); err == nil { t.Fatal("SendAll() = nil; want = error") } } func TestSendAllMalformedMultipartResponse(t *testing.T) { malformedResp := "--__END_OF_PART__\r\n" + "Content-Id: 1\r\n" + "Content-Type: application/http\r\n" + "\r\n" + "Malformed Response\r\n" + "--__END_OF_PART__--\r\n" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", wantMime) w.Write([]byte(malformedResp)) })) defer ts.Close() ctx := 
context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL if _, err = client.SendAll(ctx, testMessages); err == nil { t.Fatal("SendAll() = nil; want = error") } } func TestSendMulticastNil(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "message must not be nil" br, err := client.SendMulticast(ctx, nil) if err == nil || err.Error() != want { t.Errorf("SendMulticast(nil) = (%v, %v); want = (nil, %q)", br, err, want) } br, err = client.SendMulticastDryRun(ctx, nil) if err == nil || err.Error() != want { t.Errorf("SendMulticast(nil) = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendMulticastEmptyArray(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } want := "tokens must not be nil or empty" mm := &MulticastMessage{} br, err := client.SendMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendMulticast(Tokens: nil) = (%v, %v); want = (nil, %q)", br, err, want) } var tokens []string mm = &MulticastMessage{ Tokens: tokens, } br, err = client.SendMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendMulticast(Tokens: []) = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendMulticastTooManyTokens(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } var tokens []string for i := 0; i < 501; i++ { tokens = append(tokens, fmt.Sprintf("token%d", i)) } want := "tokens must not contain more than 500 elements" mm := &MulticastMessage{Tokens: tokens} br, err := client.SendMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendMulticast() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendMulticastInvalidMessage(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, 
testMessagingConfig) if err != nil { t.Fatal(err) } want := "invalid message at index 0: priority must be 'normal' or 'high'" mm := &MulticastMessage{ Tokens: []string{"token1"}, Android: &AndroidConfig{Priority: "invalid"}, } br, err := client.SendMulticast(ctx, mm) if err == nil || err.Error() != want { t.Errorf("SendMulticast() = (%v, %v); want = (nil, %q)", br, err, want) } } func TestSendMulticast(t *testing.T) { resp, err := createMultipartResponse(testSuccessResponse, nil) if err != nil { t.Fatal(err) } var req []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL br, err := client.SendMulticast(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponse(br, req, false); err != nil { t.Errorf("SendMulticast() = %v", err) } } func TestSendMulticastWithCustomEndpoint(t *testing.T) { resp, err := createMultipartResponse(testSuccessResponse, nil) if err != nil { t.Fatal(err) } var req []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() conf := *testMessagingConfig customBatchEndpoint := fmt.Sprintf("%s/v1", ts.URL) optEndpoint := option.WithEndpoint(customBatchEndpoint) conf.Opts = append(conf.Opts, optEndpoint) client, err := NewClient(ctx, &conf) if err != nil { t.Fatal(err) } if customBatchEndpoint != client.batchEndpoint { t.Errorf("client.batchEndpoint = %q; want = %q", client.batchEndpoint, customBatchEndpoint) } br, err := client.SendMulticast(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponse(br, req, false); 
err != nil { t.Errorf("SendMulticast() = %v", err) } } func TestSendMulticastDryRun(t *testing.T) { resp, err := createMultipartResponse(testSuccessResponse, nil) if err != nil { t.Fatal(err) } var req []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { req, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL br, err := client.SendMulticastDryRun(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkSuccessfulBatchResponse(br, req, true); err != nil { t.Errorf("SendMulticastDryRun() = %v", err) } } func TestSendMulticastPartialFailure(t *testing.T) { success := []fcmResponse{testSuccessResponse[0]} var resp []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", wantMime) w.Write(resp) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.batchEndpoint = ts.URL for idx, tc := range httpErrors { failures := []string{tc.resp} resp, err = createMultipartResponse(success, failures) if err != nil { t.Fatal(err) } br, err := client.SendMulticast(ctx, testMulticastMessage) if err != nil { t.Fatal(err) } if err := checkPartialErrorBatchResponse(br, tc); err != nil { t.Errorf("[%d] SendMulticast() = %v", idx, err) } } } func checkSuccessfulBatchResponseForSendEach(br *BatchResponse, dryRun bool) error { if br.SuccessCount != 2 { return fmt.Errorf("SuccessCount = %d; want = 2", br.SuccessCount) } if br.FailureCount != 0 { return fmt.Errorf("FailureCount = %d; want = 0", br.FailureCount) } if len(br.Responses) != 2 { return fmt.Errorf("len(Responses) = %d; want = 2", len(br.Responses)) } for idx, r := range br.Responses { if err := checkSuccessfulSendResponse(r, 
testSuccessResponse[idx].Name); err != nil { return fmt.Errorf("Responses[%d]: %v", idx, err) } } return nil }

// checkSuccessfulBatchResponse verifies that br describes two successful
// sends whose message IDs match testSuccessResponse, and that req is a
// well-formed two-part multipart request with the expected dry-run flag.
func checkSuccessfulBatchResponse(br *BatchResponse, req []byte, dryRun bool) error {
	// The response-side checks are the same ones the SendEach helper performs.
	if err := checkSuccessfulBatchResponseForSendEach(br, dryRun); err != nil {
		return err
	}
	if err := checkMultipartRequest(req, dryRun); err != nil {
		return fmt.Errorf("MultipartRequest: %v", err)
	}
	return nil
}

// checkTotalErrorBatchResponse verifies that both responses in br failed with
// an error whose text equals tc.want and which satisfies tc.check.
func checkTotalErrorBatchResponse(br *BatchResponse, tc struct {
	resp, want string
	check      func(error) bool
}) error {
	switch {
	case br.SuccessCount != 0:
		return fmt.Errorf("SuccessCount = %d; want = 0", br.SuccessCount)
	case br.FailureCount != 2:
		return fmt.Errorf("FailureCount = %d; want = 2", br.FailureCount)
	case len(br.Responses) != 2:
		return fmt.Errorf("len(Responses) = %d; want = 2", len(br.Responses))
	}

	for i, sr := range br.Responses {
		switch {
		case sr.Success:
			return fmt.Errorf("Responses[%d]: Success = true; want = false", i)
		case sr.Error == nil || sr.Error.Error() != tc.want || !tc.check(sr.Error):
			return fmt.Errorf("Responses[%d]: Error = %v; want = %q", i, sr.Error, tc.want)
		case sr.MessageID != "":
			return fmt.Errorf("Responses[%d]: MessageID = %q; want = %q", i, sr.MessageID, "")
		}
	}
	return nil
}

// checkPartialErrorBatchResponse verifies that br holds exactly one success
// (index 0) followed by one failure (index 1) matching tc.want / tc.check.
func checkPartialErrorBatchResponse(br *BatchResponse, tc struct {
	resp, want string
	check      func(error) bool
}) error {
	switch {
	case br.SuccessCount != 1:
		return fmt.Errorf("SuccessCount = %d; want = 1", br.SuccessCount)
	case br.FailureCount != 1:
		return fmt.Errorf("FailureCount = %d; want = 1", br.FailureCount)
	case len(br.Responses) != 2:
		return fmt.Errorf("len(Responses) = %d; want = 2", len(br.Responses))
	}

	if err := checkSuccessfulSendResponse(br.Responses[0], testSuccessResponse[0].Name); err != nil {
		return fmt.Errorf("Responses[0]: %v", err)
	}

	failed := br.Responses[1]
	switch {
	case failed.Success:
		return fmt.Errorf("Responses[1]: Success = true; want = false")
	case failed.Error == nil || failed.Error.Error() != tc.want || !tc.check(failed.Error):
		return fmt.Errorf("Responses[1]: Error = %v; want = %q", failed.Error, tc.want)
	case failed.MessageID != "":
		return fmt.Errorf("Responses[1]: MessageID = %q; want = %q", failed.MessageID, "")
	}
	return nil
}

// checkSuccessfulSendResponse verifies that r is marked successful, carries
// no error, and has the message ID wantID.
func checkSuccessfulSendResponse(r *SendResponse, wantID string) error {
	switch {
	case !r.Success:
		return fmt.Errorf("Success = false; want = true")
	case r.Error != nil:
		return fmt.Errorf("Error = %v; want = nil", r.Error)
	case r.MessageID != wantID:
		return fmt.Errorf("MessageID = %q; want = %q", r.MessageID, wantID)
	}
	return nil
}

// checkMultipartRequest parses b as a multipart body delimited by
// multipartBoundary, validates every part with checkRequestPart, and requires
// exactly two parts.
func checkMultipartRequest(b []byte, dryRun bool) error {
	parts := multipart.NewReader(bytes.NewBuffer(b), multipartBoundary)

	seen := 0
	for {
		part, err := parts.NextPart()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}

		if partErr := checkRequestPart(part, dryRun); partErr != nil {
			return fmt.Errorf("[%d] %v", seen, partErr)
		}
		seen++
	}

	if seen != 2 {
		return fmt.Errorf("PartsCount = %d; want = 2", seen)
	}
	return nil
}

func checkRequestPart(part *multipart.Part, dryRun bool) error { r, err := http.ReadRequest(bufio.NewReader(part)) if err != nil { return err } if r.Method != http.MethodPost { return fmt.Errorf("Method = %q; want = %q", r.Method, http.MethodPost) } if r.RequestURI != wantSendURL { return fmt.Errorf("URL = %q; want = %q", r.RequestURI, wantSendURL) } if h := r.Header.Get("X-GOOG-API-FORMAT-VERSION"); h != "2" { return fmt.Errorf("X-GOOG-API-FORMAT-VERSION = %q; want = %q", h, "2") } clientVersion := "fire-admin-go/" + testMessagingConfig.Version if h := r.Header.Get("X-FIREBASE-CLIENT"); h != clientVersion { return fmt.Errorf("X-FIREBASE-CLIENT = %q; want = %q", h, clientVersion) } b, _ := ioutil.ReadAll(r.Body) var parsed 
map[string]interface{} if err := json.Unmarshal(b, &parsed); err != nil { return err } if _, ok := parsed["message"]; !ok { return fmt.Errorf("Invalid message body = %v", parsed) } validate, ok := parsed["validate_only"] if dryRun { if !ok || validate != true { return fmt.Errorf("ValidateOnly = %v; want = true", validate) } } else if ok { return fmt.Errorf("ValidateOnly = %v; want none", validate) } return nil } func createMultipartResponse(success []fcmResponse, failure []string) ([]byte, error) { var buffer bytes.Buffer writer := multipart.NewWriter(&buffer) writer.SetBoundary(multipartBoundary) for idx, data := range success { b, err := json.Marshal(data) if err != nil { return nil, err } var partBuffer bytes.Buffer partBuffer.WriteString("HTTP/1.1 200 OK\r\n") partBuffer.WriteString("Content-Type: application/json\r\n\r\n") partBuffer.Write(b) if err := writeResponsePart(writer, partBuffer.Bytes(), idx); err != nil { return nil, err } } for idx, data := range failure { var partBuffer bytes.Buffer partBuffer.WriteString("HTTP/1.1 500 Internal Server Error\r\n") partBuffer.WriteString("Content-Type: application/json\r\n\r\n") partBuffer.WriteString(data) if err := writeResponsePart(writer, partBuffer.Bytes(), idx+len(success)); err != nil { return nil, err } } writer.Close() return buffer.Bytes(), nil } func writeResponsePart(writer *multipart.Writer, data []byte, idx int) error { header := make(textproto.MIMEHeader) header.Add("Content-Type", "application/http") header.Add("Content-Id", fmt.Sprintf("%d", idx+1)) part, err := writer.CreatePart(header) if err != nil { return err } _, err = part.Write(data) return err } golang-google-firebase-go-4.18.0/messaging/messaging_test.go000066400000000000000000001120511505612111400240220ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messaging import ( "context" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" "reflect" "testing" "time" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" ) const testMessageID = "projects/test-project/messages/msg_id" var ( testMessagingConfig = &internal.MessagingConfig{ ProjectID: "test-project", Opts: []option.ClientOption{ option.WithTokenSource(&internal.MockTokenSource{AccessToken: "test-token"}), }, Version: "test-version", } ttlWithNanos = time.Duration(1500) * time.Millisecond ttl = time.Duration(10) * time.Second invalidTTL = time.Duration(-10) * time.Second badge = 42 badgeZero = 0 timestampMillis = int64(12345) timestamp = time.Unix(0, 1546304523123*1000000).UTC() ) var validMessages = []struct { name string req *Message want map[string]interface{} }{ { name: "TokenOnly", req: &Message{Token: "test-token"}, want: map[string]interface{}{"token": "test-token"}, }, { name: "TopicOnly", req: &Message{Topic: "test-topic"}, want: map[string]interface{}{"topic": "test-topic"}, }, { name: "PrefixedTopicOnly", req: &Message{Topic: "/topics/test-topic"}, want: map[string]interface{}{"topic": "test-topic"}, }, { name: "ConditionOnly", req: &Message{Condition: "test-condition"}, want: map[string]interface{}{"condition": "test-condition"}, }, { name: "DataMessage", req: &Message{ Data: map[string]string{ "k1": "v1", "k2": "v2", }, FCMOptions: &FCMOptions{ AnalyticsLabel: "Analytics", }, Topic: "test-topic", }, want: map[string]interface{}{ "data": map[string]interface{}{ "k1": "v1", 
"k2": "v2", }, "fcm_options": map[string]interface{}{ "analytics_label": "Analytics", }, "topic": "test-topic", }, }, { name: "NotificationMessage", req: &Message{ Notification: &Notification{ Title: "t", Body: "b", ImageURL: "http://image.jpg", }, Topic: "test-topic", }, want: map[string]interface{}{ "notification": map[string]interface{}{ "title": "t", "body": "b", "image": "http://image.jpg", }, "topic": "test-topic", }, }, { name: "AndroidDataMessage", req: &Message{ Android: &AndroidConfig{ CollapseKey: "ck", Data: map[string]string{ "k1": "v1", "k2": "v2", }, Priority: "normal", TTL: &ttl, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "collapse_key": "ck", "data": map[string]interface{}{ "k1": "v1", "k2": "v2", }, "priority": "normal", "ttl": "10s", }, "topic": "test-topic", }, }, { name: "AndroidDataMessage", req: &Message{ Android: &AndroidConfig{ DirectBootOK: true, CollapseKey: "ck", Data: map[string]string{ "k1": "v1", "k2": "v2", }, Priority: "normal", TTL: &ttl, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "direct_boot_ok": true, "collapse_key": "ck", "data": map[string]interface{}{ "k1": "v1", "k2": "v2", }, "priority": "normal", "ttl": "10s", }, "topic": "test-topic", }, }, { name: "AndroidNotificationMessage", req: &Message{ Android: &AndroidConfig{ RestrictedPackageName: "rpn", Notification: &AndroidNotification{ Title: "t", Body: "b", Color: "#112233", Sound: "s", TitleLocKey: "tlk", TitleLocArgs: []string{"t1", "t2"}, BodyLocKey: "blk", BodyLocArgs: []string{"b1", "b2"}, ChannelID: "channel", ImageURL: "http://image.jpg", Ticker: "tkr", Sticky: true, EventTimestamp: ×tamp, LocalOnly: true, Priority: PriorityMax, VibrateTimingMillis: []int64{100, 50, 100}, DefaultVibrateTimings: true, DefaultSound: true, LightSettings: &LightSettings{ Color: "#33669966", LightOnDurationMillis: 100, LightOffDurationMillis: 50, }, Visibility: VisibilityPrivate, 
DefaultLightSettings: true, }, TTL: &ttlWithNanos, FCMOptions: &AndroidFCMOptions{ AnalyticsLabel: "Analytics", }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "restricted_package_name": "rpn", "notification": map[string]interface{}{ "title": "t", "body": "b", "color": "#112233", "sound": "s", "title_loc_key": "tlk", "title_loc_args": []interface{}{"t1", "t2"}, "body_loc_key": "blk", "body_loc_args": []interface{}{"b1", "b2"}, "channel_id": "channel", "image": "http://image.jpg", "ticker": "tkr", "sticky": true, "event_time": "2019-01-01T01:02:03.123000000Z", "local_only": true, "notification_priority": "PRIORITY_MAX", "vibrate_timings": []interface{}{"0.100000000s", "0.050000000s", "0.100000000s"}, "default_vibrate_timings": true, "default_sound": true, "light_settings": map[string]interface{}{ "color": map[string]interface{}{ "red": float64(0.2), "green": float64(0.4), "blue": float64(0.6), "alpha": float64(0.4), }, "light_on_duration": "0.100000000s", "light_off_duration": "0.050000000s", }, "visibility": "PRIVATE", "default_light_settings": true, }, "ttl": "1.500000000s", "fcm_options": map[string]interface{}{ "analytics_label": "Analytics", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationLightSettings", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ LightSettings: &LightSettings{ Color: "#336699", LightOnDurationMillis: 100, LightOffDurationMillis: 50, }, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "light_settings": map[string]interface{}{ "color": map[string]interface{}{ "red": float64(0.2), "green": float64(0.4), "blue": float64(0.6), "alpha": float64(1.0), }, "light_on_duration": "0.100000000s", "light_off_duration": "0.050000000s", }, }, }, "topic": "test-topic", }, }, { name: "AndroidNoTTL", req: &Message{ Android: &AndroidConfig{ Priority: "high", }, Topic: "test-topic", }, 
want: map[string]interface{}{ "android": map[string]interface{}{ "priority": "high", }, "topic": "test-topic", }, }, { name: "WebpushMessage", req: &Message{ Webpush: &WebpushConfig{ Headers: map[string]string{ "h1": "v1", "h2": "v2", }, Data: map[string]string{ "k1": "v1", "k2": "v2", }, Notification: &WebpushNotification{ Title: "title", Body: "body", Icon: "icon", Actions: []*WebpushNotificationAction{ { Action: "a1", Title: "a1-title", }, { Action: "a2", Title: "a2-title", Icon: "a2-icon", }, }, Badge: "badge", Data: "data", Image: "image", Language: "lang", Renotify: true, RequireInteraction: true, Silent: true, Tag: "tag", TimestampMillis: ×tampMillis, Vibrate: []int{100, 200, 100}, CustomData: map[string]interface{}{"k1": "v1", "k2": "v2"}, }, FCMOptions: &WebpushFCMOptions{ Link: "https://link.com", }, }, Topic: "test-topic", }, want: map[string]interface{}{ "webpush": map[string]interface{}{ "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, "data": map[string]interface{}{"k1": "v1", "k2": "v2"}, "notification": map[string]interface{}{ "title": "title", "body": "body", "icon": "icon", "actions": []interface{}{ map[string]interface{}{"action": "a1", "title": "a1-title"}, map[string]interface{}{"action": "a2", "title": "a2-title", "icon": "a2-icon"}, }, "badge": "badge", "data": "data", "image": "image", "lang": "lang", "renotify": true, "requireInteraction": true, "silent": true, "tag": "tag", "timestamp": float64(12345), "vibrate": []interface{}{float64(100), float64(200), float64(100)}, "k1": "v1", "k2": "v2", }, "fcm_options": map[string]interface{}{ "link": "https://link.com", }, }, "topic": "test-topic", }, }, { name: "APNSHeadersOnly", req: &Message{ APNS: &APNSConfig{ Headers: map[string]string{ "h1": "v1", "h2": "v2", }, }, Topic: "test-topic", }, want: map[string]interface{}{ "apns": map[string]interface{}{ "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, }, "topic": "test-topic", }, }, { name: "APNSAlertString", req: &Message{ 
APNS: &APNSConfig{ Headers: map[string]string{ "h1": "v1", "h2": "v2", }, Payload: &APNSPayload{ Aps: &Aps{ AlertString: "a", Badge: &badge, Category: "c", Sound: "s", ThreadID: "t", ContentAvailable: true, MutableContent: true, }, CustomData: map[string]interface{}{ "k1": "v1", "k2": true, }, }, FCMOptions: &APNSFCMOptions{ AnalyticsLabel: "Analytics", ImageURL: "http://image.jpg", }, }, Topic: "test-topic", }, want: map[string]interface{}{ "apns": map[string]interface{}{ "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, "payload": map[string]interface{}{ "aps": map[string]interface{}{ "alert": "a", "badge": float64(badge), "category": "c", "sound": "s", "thread-id": "t", "content-available": float64(1), "mutable-content": float64(1), }, "k1": "v1", "k2": true, }, "fcm_options": map[string]interface{}{ "analytics_label": "Analytics", "image": "http://image.jpg", }, }, "topic": "test-topic", }, }, { name: "APNSAlertCrticalSound", req: &Message{ APNS: &APNSConfig{ Headers: map[string]string{ "h1": "v1", "h2": "v2", }, Payload: &APNSPayload{ Aps: &Aps{ AlertString: "a", Badge: &badge, Category: "c", CriticalSound: &CriticalSound{ Critical: true, Name: "n", Volume: 0.7, }, ThreadID: "t", ContentAvailable: true, MutableContent: true, }, CustomData: map[string]interface{}{ "k1": "v1", "k2": true, }, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "apns": map[string]interface{}{ "headers": map[string]interface{}{"h1": "v1", "h2": "v2"}, "payload": map[string]interface{}{ "aps": map[string]interface{}{ "alert": "a", "badge": float64(badge), "category": "c", "sound": map[string]interface{}{ "critical": float64(1), "name": "n", "volume": float64(0.7), }, "thread-id": "t", "content-available": float64(1), "mutable-content": float64(1), }, "k1": "v1", "k2": true, }, }, "topic": "test-topic", }, }, { name: "APNSBadgeZero", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Badge: &badgeZero, Category: "c", Sound: "s", ThreadID: "t", 
ContentAvailable: true, MutableContent: true, CustomData: map[string]interface{}{"k1": "v1", "k2": float64(1)}, }, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "apns": map[string]interface{}{ "payload": map[string]interface{}{ "aps": map[string]interface{}{ "badge": float64(badgeZero), "category": "c", "sound": "s", "thread-id": "t", "content-available": float64(1), "mutable-content": float64(1), "k1": "v1", "k2": float64(1), }, }, }, "topic": "test-topic", }, }, { name: "APNSAlertObject", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Alert: &ApsAlert{ Title: "t", SubTitle: "st", Body: "b", TitleLocKey: "tlk", TitleLocArgs: []string{"t1", "t2"}, SubTitleLocKey: "stlk", SubTitleLocArgs: []string{"t1", "t2"}, LocKey: "blk", LocArgs: []string{"b1", "b2"}, ActionLocKey: "alk", LaunchImage: "li", }, }, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "apns": map[string]interface{}{ "payload": map[string]interface{}{ "aps": map[string]interface{}{ "alert": map[string]interface{}{ "title": "t", "subtitle": "st", "body": "b", "title-loc-key": "tlk", "title-loc-args": []interface{}{"t1", "t2"}, "subtitle-loc-key": "stlk", "subtitle-loc-args": []interface{}{"t1", "t2"}, "loc-key": "blk", "loc-args": []interface{}{"b1", "b2"}, "action-loc-key": "alk", "launch-image": "li", }, }, }, }, "topic": "test-topic", }, }, { name: "APNSLiveActivity", req: &Message{ Token: "test-token", APNS: &APNSConfig{ LiveActivityToken: "live-activity-token", }, }, want: map[string]interface{}{ "token": "test-token", "apns": map[string]interface{}{ "live_activity_token": "live-activity-token", }, }, }, { name: "AndroidNotificationPriorityMin", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Priority: PriorityMin, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "notification_priority": "PRIORITY_MIN", }, }, "topic": "test-topic", }, }, { 
name: "AndroidNotificationPriorityLow", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Priority: PriorityLow, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "notification_priority": "PRIORITY_LOW", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationPriorityDefault", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Priority: PriorityDefault, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "notification_priority": "PRIORITY_DEFAULT", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationPriorityHigh", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Priority: PriorityHigh, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "notification_priority": "PRIORITY_HIGH", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationPriorityMax", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Priority: PriorityMax, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "notification_priority": "PRIORITY_MAX", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationProxyAllow", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Proxy: ProxyAllow, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "proxy": "ALLOW", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationProxyDeny", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Proxy: ProxyDeny, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": 
map[string]interface{}{ "proxy": "DENY", }, }, "topic": "test-topic", }, }, { name: "AndroidNotificationProxyIfPriorityLowered", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Proxy: ProxyIfPriorityLowered, }, }, Topic: "test-topic", }, want: map[string]interface{}{ "android": map[string]interface{}{ "notification": map[string]interface{}{ "proxy": "IF_PRIORITY_LOWERED", }, }, "topic": "test-topic", }, }, } var invalidMessages = []struct { name string req *Message want string }{ { name: "NilMessage", req: nil, want: "message must not be nil", }, { name: "NoTargets", req: &Message{}, want: "exactly one of token, topic or condition must be specified", }, { name: "MultipleTargets", req: &Message{ Token: "token", Topic: "topic", }, want: "exactly one of token, topic or condition must be specified", }, { name: "InvalidPrefixedTopicName", req: &Message{ Topic: "/topics/", }, want: "malformed topic name", }, { name: "InvalidTopicName", req: &Message{ Topic: "foo*bar", }, want: "malformed topic name", }, { name: "InvalidNotificationImage", req: &Message{ Notification: &Notification{ ImageURL: "image.jpg", }, Topic: "topic", }, want: `invalid image URL: "image.jpg"`, }, { name: "InvalidAndroidTTL", req: &Message{ Android: &AndroidConfig{ TTL: &invalidTTL, }, Topic: "topic", }, want: "ttl duration must not be negative", }, { name: "InvalidAndroidPriority", req: &Message{ Android: &AndroidConfig{ Priority: "not normal", }, Topic: "topic", }, want: "priority must be 'normal' or 'high'", }, { name: "InvalidAndroidColor1", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Color: "112233", }, }, Topic: "topic", }, want: "color must be in the #RRGGBB form", }, { name: "InvalidAndroidColor2", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ Color: "#112233X", }, }, Topic: "topic", }, want: "color must be in the #RRGGBB form", }, { name: "InvalidAndroidTitleLocArgs", req: &Message{ Android: 
&AndroidConfig{ Notification: &AndroidNotification{ TitleLocArgs: []string{"a1"}, }, }, Topic: "topic", }, want: "titleLocKey is required when specifying titleLocArgs", }, { name: "InvalidAndroidBodyLocArgs", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ BodyLocArgs: []string{"a1"}, }, }, Topic: "topic", }, want: "bodyLocKey is required when specifying bodyLocArgs", }, { name: "InvalidAndroidImage", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ ImageURL: "image.jpg", }, }, Topic: "topic", }, want: `invalid image URL: "image.jpg"`, }, { name: "InvalidLightSettingsColor1", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ LightSettings: &LightSettings{ Color: "112233", }, }, }, Topic: "topic", }, want: "color must be in #RRGGBB or #RRGGBBAA form", }, { name: "InvalidLightSettingsColor2", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ LightSettings: &LightSettings{ Color: "#11223X", }, }, }, Topic: "topic", }, want: "color must be in #RRGGBB or #RRGGBBAA form", }, { name: "InvalidLightSettingsColor3", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ LightSettings: &LightSettings{ Color: "#112234X", }, }, }, Topic: "topic", }, want: "color must be in #RRGGBB or #RRGGBBAA form", }, { name: "InvalidLightSettingsOnDuration", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ LightSettings: &LightSettings{ Color: "#112233", LightOnDurationMillis: -1, }, }, }, Topic: "topic", }, want: "lightOnDuration must not be negative", }, { name: "InvalidLightSettingsOffDuration", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ LightSettings: &LightSettings{ Color: "#112233", LightOffDurationMillis: -1, }, }, }, Topic: "topic", }, want: "lightOffDuration must not be negative", }, { name: "InvalidVibrateTimings", req: &Message{ Android: &AndroidConfig{ Notification: &AndroidNotification{ 
VibrateTimingMillis: []int64{100, -1, 100}, }, }, Topic: "topic", }, want: "vibrateTimingMillis must not be negative", }, { name: "APNSMultipleAps", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ AlertString: "alert", }, CustomData: map[string]interface{}{ "aps": map[string]interface{}{"key": "value"}, }, }, }, Topic: "topic", }, want: `multiple specifications for the key "aps"`, }, { name: "APNSMultipleAlerts", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Alert: &ApsAlert{}, AlertString: "alert", }, }, }, Topic: "topic", }, want: "multiple alert specifications", }, { name: "APNSMultipleFieldSpecifications", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Category: "category", CustomData: map[string]interface{}{"category": "category"}, }, }, }, Topic: "topic", }, want: `multiple specifications for the key "category"`, }, { name: "InvalidAPNSTitleLocArgs", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Alert: &ApsAlert{ TitleLocArgs: []string{"a1"}, }, }, }, }, Topic: "topic", }, want: "titleLocKey is required when specifying titleLocArgs", }, { name: "InvalidAPNSSubTitleLocArgs", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Alert: &ApsAlert{ SubTitleLocArgs: []string{"a1"}, }, }, }, }, Topic: "topic", }, want: "subtitleLocKey is required when specifying subtitleLocArgs", }, { name: "InvalidAPNSLocArgs", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Alert: &ApsAlert{ LocArgs: []string{"a1"}, }, }, }, }, Topic: "topic", }, want: "locKey is required when specifying locArgs", }, { name: "InvalidAPNSImage", req: &Message{ APNS: &APNSConfig{ FCMOptions: &APNSFCMOptions{ ImageURL: "image.jpg", }, }, Topic: "topic", }, want: `invalid image URL: "image.jpg"`, }, { name: "MultipleSoundSpecifications", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ Sound: "s", CriticalSound: &CriticalSound{ Name: "s", }, }, }, }, Topic: 
"topic", }, want: "multiple sound specifications", }, { name: "VolumeTooLow", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ CriticalSound: &CriticalSound{ Name: "s", Volume: -0.1, }, }, }, }, Topic: "topic", }, want: "critical sound volume must be in the interval [0, 1]", }, { name: "VolumeTooHigh", req: &Message{ APNS: &APNSConfig{ Payload: &APNSPayload{ Aps: &Aps{ CriticalSound: &CriticalSound{ Name: "s", Volume: 1.1, }, }, }, }, Topic: "topic", }, want: "critical sound volume must be in the interval [0, 1]", }, { name: "InvalidWebpushNotificationDirection", req: &Message{ Webpush: &WebpushConfig{ Notification: &WebpushNotification{ Direction: "invalid", }, }, Topic: "topic", }, want: "direction must be 'ltr', 'rtl' or 'auto'", }, { name: "WebpushNotificationMultipleFieldSpecifications", req: &Message{ Webpush: &WebpushConfig{ Notification: &WebpushNotification{ Direction: "ltr", CustomData: map[string]interface{}{"dir": "rtl"}, }, }, Topic: "topic", }, want: `multiple specifications for the key "dir"`, }, { name: "InvalidWebpushFcmOptionsLink", req: &Message{ Webpush: &WebpushConfig{ Notification: &WebpushNotification{}, FCMOptions: &WebpushFCMOptions{ Link: "link", }, }, Topic: "topic", }, want: `invalid link URL: "link"`, }, { name: "InvalidWebpushFcmOptionsLinkScheme", req: &Message{ Webpush: &WebpushConfig{ Notification: &WebpushNotification{}, FCMOptions: &WebpushFCMOptions{ Link: "http://link.com", }, }, Topic: "topic", }, want: `invalid link URL: "http://link.com"; want scheme: "https"`, }, } func TestNoProjectID(t *testing.T) { client, err := NewClient(context.Background(), &internal.MessagingConfig{}) if client != nil || err == nil { t.Errorf("NewClient() = (%v, %v); want = (nil, error)", client, err) } } func TestJSONUnmarshal(t *testing.T) { for _, tc := range validMessages { if tc.name == "PrefixedTopicOnly" { continue } b, err := json.Marshal(tc.req) if err != nil { t.Errorf("Marshal(%s) = %v; want = nil", tc.name, err) } var 
target Message if err := json.Unmarshal(b, &target); err != nil { t.Errorf("Unmarshal(%s) = %v; want = nil", tc.name, err) } if !reflect.DeepEqual(tc.req, &target) { t.Errorf("Unmarshal(%s) result = %#v; want = %#v", tc.name, tc.req, target) } } } func TestInvalidJSONUnmarshal(t *testing.T) { cases := []struct { name string req map[string]interface{} target interface{} }{ { name: "InvalidTTLSegments", req: map[string]interface{}{ "ttl": "1.2.3s", }, target: &AndroidConfig{}, }, { name: "IncorrectTTLSeconds", req: map[string]interface{}{ "ttl": "abcs", }, target: &AndroidConfig{}, }, { name: "IncorrectTTLNanoseconds", req: map[string]interface{}{ "ttl": "10.abcs", }, target: &AndroidConfig{}, }, { name: "InvalidApsAlert", req: map[string]interface{}{ "alert": 10, }, target: &Aps{}, }, { name: "InvalidApsSound", req: map[string]interface{}{ "sound": 10, }, target: &Aps{}, }, { name: "InvalidPriority", req: map[string]interface{}{ "notification_priority": "invalid", }, target: &AndroidNotification{}, }, { name: "InvalidVisibility", req: map[string]interface{}{ "visibility": "invalid", }, target: &AndroidNotification{}, }, { name: "InvalidEventTimestamp", req: map[string]interface{}{ "event_time": "invalid", }, target: &AndroidNotification{}, }, { name: "IncorrectLightOnDuration", req: map[string]interface{}{ "light_on_duration": "10.abcs", }, target: &LightSettings{}, }, { name: "IncorrectLightOffDuration", req: map[string]interface{}{ "light_on_duration": "1s", "light_off_duration": "10.abcs", }, target: &LightSettings{}, }, } for _, tc := range cases { b, err := json.Marshal(tc.req) if err != nil { t.Errorf("Marshal(%s) = %v; want = nil", tc.name, err) } if err := json.Unmarshal(b, tc.target); err == nil { t.Errorf("Unmarshal(%s) = %v; want = error", tc.name, err) } } } func TestSend(t *testing.T) { var tr *http.Request var b []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r b, _ = ioutil.ReadAll(r.Body) 
w.Header().Set("Content-Type", "application/json") w.Write([]byte("{ \"name\":\"" + testMessageID + "\" }")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL for _, tc := range validMessages { t.Run(tc.name, func(t *testing.T) { name, err := client.Send(ctx, tc.req) if name != testMessageID || err != nil { t.Errorf("Send(%s) = (%q, %v); want = (%q, nil)", tc.name, name, err, testMessageID) } checkFCMRequest(t, b, tr, tc.want, false) }) } } func TestSendWithCustomEndpoint(t *testing.T) { var tr *http.Request var b []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r b, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") w.Write([]byte("{ \"name\":\"" + testMessageID + "\" }")) })) defer ts.Close() ctx := context.Background() conf := *testMessagingConfig optEndpoint := option.WithEndpoint(ts.URL) conf.Opts = append(conf.Opts, optEndpoint) client, err := NewClient(ctx, &conf) if err != nil { t.Fatal(err) } if ts.URL != client.fcmEndpoint { t.Errorf("client.fcmEndpoint = %q; want = %q", client.fcmEndpoint, ts.URL) } for _, tc := range validMessages { t.Run(tc.name, func(t *testing.T) { name, err := client.Send(ctx, tc.req) if name != testMessageID || err != nil { t.Errorf("Send(%s) = (%q, %v); want = (%q, nil)", tc.name, name, err, testMessageID) } checkFCMRequest(t, b, tr, tc.want, false) }) } } func TestSendDryRun(t *testing.T) { var tr *http.Request var b []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r b, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") w.Write([]byte("{ \"name\":\"" + testMessageID + "\" }")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.fcmEndpoint = ts.URL for _, tc := range validMessages { 
t.Run(tc.name, func(t *testing.T) { name, err := client.SendDryRun(ctx, tc.req) if name != testMessageID || err != nil { t.Errorf("SendDryRun(%s) = (%q, %v); want = (%q, nil)", tc.name, name, err, testMessageID) } checkFCMRequest(t, b, tr, tc.want, true) }) } }

// TestSendError replays every entry of httpErrors as the body of an HTTP 500
// response and checks that Send returns an error with the expected text that
// also satisfies the entry's check predicate. Retries are disabled so each
// case results in a single request.
func TestSendError(t *testing.T) {
	var resp string
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Headers must be set before WriteHeader; changes made to the header
		// map afterwards are not sent to the client.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(resp))
	}))
	defer ts.Close()

	ctx := context.Background()
	client, err := NewClient(ctx, testMessagingConfig)
	if err != nil {
		t.Fatal(err)
	}
	client.fcmEndpoint = ts.URL
	client.fcmClient.httpClient.RetryConfig = nil

	for idx, tc := range httpErrors {
		resp = tc.resp
		name, err := client.Send(ctx, &Message{Topic: "topic"})
		if err == nil || err.Error() != tc.want || !tc.check(err) {
			t.Errorf("Send(%d) = (%q, %v); want = (%q, %q)", idx, name, err, "", tc.want)
		}
	}
}

// TestInvalidMessage checks that Send rejects every entry of invalidMessages
// with exactly the expected error text. The client's endpoint is not
// redirected to a stub server here.
func TestInvalidMessage(t *testing.T) {
	ctx := context.Background()
	client, err := NewClient(ctx, testMessagingConfig)
	if err != nil {
		t.Fatal(err)
	}

	for _, tc := range invalidMessages {
		t.Run(tc.name, func(t *testing.T) {
			name, err := client.Send(ctx, tc.req)
			if err == nil || err.Error() != tc.want {
				t.Errorf("Send(%s) = (%q, %v); want = (%q, %q)", tc.name, name, err, "", tc.want)
			}
		})
	}
}

// checkFCMRequest validates a request captured by a stub FCM server.
//
// b is the raw request body and tr the request itself. The body's "message"
// field must deep-equal want; "validate_only" must be true when dryRun is set
// and absent otherwise. The method, path, and the Authorization, format
// version, and client identification headers are checked as well.
func checkFCMRequest(t *testing.T, b []byte, tr *http.Request, want map[string]interface{}, dryRun bool) {
	// Attribute failures to the calling test rather than to this helper.
	t.Helper()

	var parsed map[string]interface{}
	if err := json.Unmarshal(b, &parsed); err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(parsed["message"], want) {
		t.Errorf("Body = %#v; want = %#v", parsed["message"], want)
	}

	validate, ok := parsed["validate_only"]
	if dryRun {
		if !ok || validate != true {
			t.Errorf("ValidateOnly = %v; want = true", validate)
		}
	} else if ok {
		t.Errorf("ValidateOnly = %v; want none", validate)
	}

	if tr.Method != http.MethodPost {
		t.Errorf("Method = %q; want = %q", tr.Method, http.MethodPost)
	}
	if tr.URL.Path != "/projects/test-project/messages:send" {
		t.Errorf("Path = %q; want = %q", tr.URL.Path, "/projects/test-project/messages:send")
	}
	if h := tr.Header.Get("Authorization"); h != "Bearer test-token" {
		t.Errorf("Authorization = %q; want = %q", h, "Bearer test-token")
	}
	if h := tr.Header.Get("X-GOOG-API-FORMAT-VERSION"); h != "2" {
		t.Errorf("X-GOOG-API-FORMAT-VERSION = %q; want = %q", h, "2")
	}
	clientVersion := "fire-admin-go/" + testMessagingConfig.Version
	if h := tr.Header.Get("X-FIREBASE-CLIENT"); h != clientVersion {
		t.Errorf("X-FIREBASE-CLIENT = %q; want = %q", h, clientVersion)
	}
	xGoogAPIClientHeader := internal.GetMetricsHeader(testMessagingConfig.Version)
	if h := tr.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader {
		t.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader)
	}
}

var httpErrors = []struct {
	resp, want string
	check      func(error) bool
}{
	{
		resp:  "{}",
		want:  "unexpected http response with status: 500\n{}",
		check: errorutils.IsInternal,
	},
	{
		resp:  "{\"error\": {\"status\": \"INVALID_ARGUMENT\", \"message\": \"test error\"}}",
		want:  "test error",
		check: errorutils.IsInvalidArgument,
	},
	{
		resp:  "{\"error\": {\"status\": \"NOT_FOUND\", \"message\": \"test error\"}}",
		want:  "test error",
		check: errorutils.IsNotFound,
	},
	{
		resp:  "{\"error\": {\"status\": \"RESOURCE_EXHAUSTED\", \"message\": \"test error\"}}",
		want:  "test error",
		check: errorutils.IsResourceExhausted,
	},
	{
		resp:  "{\"error\": {\"status\": \"UNAVAILABLE\", \"message\": \"test error\"}}",
		want:  "test error",
		check: errorutils.IsUnavailable,
	},
	{
		resp:  "{\"error\": {\"status\": \"INTERNAL\", \"message\": \"test error\"}}",
		want:  "test error",
		check: errorutils.IsInternal,
	},
	{
		resp: `{"error": {"status": "INVALID_ARGUMENT", "message": "test error", "details": [` +
			`{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "UNREGISTERED"}]}}`,
		want: "test error",
		check: func(err error) bool {
			return IsRegistrationTokenNotRegistered(err)
&& IsUnregistered(err) }, }, { resp: `{"error": {"status": "INVALID_ARGUMENT", "message": "test error", "details": [` + `{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "SENDER_ID_MISMATCH"}]}}`, want: "test error", check: func(err error) bool { return IsMismatchedCredential(err) && IsSenderIDMismatch(err) }, }, { resp: `{"error": {"status": "RESOURCE_EXHAUSTED", "message": "test error", "details": [` + `{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "QUOTA_EXCEEDED"}]}}`, want: "test error", check: func(err error) bool { return IsMessageRateExceeded(err) && IsQuotaExceeded(err) }, }, { resp: `{"error": {"status": "UNAVAILABLE", "message": "test error", "details": [` + `{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "UNAVAILABLE"}]}}`, want: "test error", check: func(err error) bool { return IsServerUnavailable(err) && IsUnavailable(err) }, }, { resp: `{"error": {"status": "INTERNAL", "message": "test error", "details": [` + `{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "INTERNAL"}]}}`, want: "test error", check: IsInternal, }, { resp: `{"error": {"status": "INVALID_ARGUMENT", "message": "test error", "details": [` + `{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "INVALID_ARGUMENT"}]}}`, want: "test error", check: IsInvalidArgument, }, { resp: `{"error": {"status": "INVALID_ARGUMENT", "message": "test error", "details": [` + `{"@type": "type.googleapis.com/google.firebase.fcm.v1.FcmError", "errorCode": "THIRD_PARTY_AUTH_ERROR"}]}}`, want: "test error", check: func(err error) bool { return IsInvalidAPNSCredentials(err) && IsThirdPartyAuthError(err) }, }, { resp: "not json", want: "unexpected http response with status: 500\nnot json", check: errorutils.IsInternal, }, } golang-google-firebase-go-4.18.0/messaging/messaging_utils.go000066400000000000000000000150771505612111400242150ustar00rootroot00000000000000// 
Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messaging import ( "errors" "fmt" "net/url" "regexp" "strings" ) var ( bareTopicNamePattern = regexp.MustCompile("^[a-zA-Z0-9-_.~%]+$") colorPattern = regexp.MustCompile("^#[0-9a-fA-F]{6}$") colorWithAlphaPattern = regexp.MustCompile("^#[0-9a-fA-F]{6}([0-9a-fA-F]{2})?$") ) func validateMessage(message *Message) error { if message == nil { return fmt.Errorf("message must not be nil") } targets := countNonEmpty(message.Token, message.Condition, message.Topic) if targets != 1 { return fmt.Errorf("exactly one of token, topic or condition must be specified") } // validate topic if message.Topic != "" { bt := strings.TrimPrefix(message.Topic, "/topics/") if !bareTopicNamePattern.MatchString(bt) { return fmt.Errorf("malformed topic name") } } // validate Notification if err := validateNotification(message.Notification); err != nil { return err } // validate AndroidConfig if err := validateAndroidConfig(message.Android); err != nil { return err } // validate WebpushConfig if err := validateWebpushConfig(message.Webpush); err != nil { return err } // validate APNSConfig return validateAPNSConfig(message.APNS) } func validateNotification(notification *Notification) error { if notification == nil { return nil } image := notification.ImageURL if image != "" { if _, err := url.ParseRequestURI(image); err != nil { return fmt.Errorf("invalid image URL: %q", image) } } return nil } 
func validateAndroidConfig(config *AndroidConfig) error { if config == nil { return nil } if config.TTL != nil && config.TTL.Seconds() < 0 { return fmt.Errorf("ttl duration must not be negative") } if config.Priority != "" && config.Priority != "normal" && config.Priority != "high" { return fmt.Errorf("priority must be 'normal' or 'high'") } // validate AndroidNotification return validateAndroidNotification(config.Notification) } func validateAndroidNotification(notification *AndroidNotification) error { if notification == nil { return nil } if notification.Color != "" && !colorPattern.MatchString(notification.Color) { return fmt.Errorf("color must be in the #RRGGBB form") } if len(notification.TitleLocArgs) > 0 && notification.TitleLocKey == "" { return fmt.Errorf("titleLocKey is required when specifying titleLocArgs") } if len(notification.BodyLocArgs) > 0 && notification.BodyLocKey == "" { return fmt.Errorf("bodyLocKey is required when specifying bodyLocArgs") } image := notification.ImageURL if image != "" { if _, err := url.ParseRequestURI(image); err != nil { return fmt.Errorf("invalid image URL: %q", image) } } for _, timing := range notification.VibrateTimingMillis { if timing < 0 { return fmt.Errorf("vibrateTimingMillis must not be negative") } } return validateLightSettings(notification.LightSettings) } func validateLightSettings(light *LightSettings) error { if light == nil { return nil } if !colorWithAlphaPattern.MatchString(light.Color) { return errors.New("color must be in #RRGGBB or #RRGGBBAA form") } if light.LightOnDurationMillis < 0 { return errors.New("lightOnDuration must not be negative") } if light.LightOffDurationMillis < 0 { return errors.New("lightOffDuration must not be negative") } return nil } func validateAPNSConfig(config *APNSConfig) error { if config != nil { // validate FCMOptions if config.FCMOptions != nil { image := config.FCMOptions.ImageURL if image != "" { if _, err := url.ParseRequestURI(image); err != nil { return 
fmt.Errorf("invalid image URL: %q", image) } } } return validateAPNSPayload(config.Payload) } return nil } func validateAPNSPayload(payload *APNSPayload) error { if payload != nil { m := payload.standardFields() for k := range payload.CustomData { if _, contains := m[k]; contains { return fmt.Errorf("multiple specifications for the key %q", k) } } return validateAps(payload.Aps) } return nil } func validateAps(aps *Aps) error { if aps != nil { if aps.Alert != nil && aps.AlertString != "" { return fmt.Errorf("multiple alert specifications") } if aps.CriticalSound != nil { if aps.Sound != "" { return fmt.Errorf("multiple sound specifications") } if aps.CriticalSound.Volume < 0 || aps.CriticalSound.Volume > 1 { return fmt.Errorf("critical sound volume must be in the interval [0, 1]") } } m := aps.standardFields() for k := range aps.CustomData { if _, contains := m[k]; contains { return fmt.Errorf("multiple specifications for the key %q", k) } } return validateApsAlert(aps.Alert) } return nil } func validateApsAlert(alert *ApsAlert) error { if alert == nil { return nil } if len(alert.TitleLocArgs) > 0 && alert.TitleLocKey == "" { return fmt.Errorf("titleLocKey is required when specifying titleLocArgs") } if len(alert.SubTitleLocArgs) > 0 && alert.SubTitleLocKey == "" { return fmt.Errorf("subtitleLocKey is required when specifying subtitleLocArgs") } if len(alert.LocArgs) > 0 && alert.LocKey == "" { return fmt.Errorf("locKey is required when specifying locArgs") } return nil } func validateWebpushConfig(webpush *WebpushConfig) error { if webpush == nil || webpush.Notification == nil { return nil } dir := webpush.Notification.Direction if dir != "" && dir != "ltr" && dir != "rtl" && dir != "auto" { return fmt.Errorf("direction must be 'ltr', 'rtl' or 'auto'") } m := webpush.Notification.standardFields() for k := range webpush.Notification.CustomData { if _, contains := m[k]; contains { return fmt.Errorf("multiple specifications for the key %q", k) } } if 
webpush.FCMOptions != nil { link := webpush.FCMOptions.Link p, err := url.ParseRequestURI(link) if err != nil { return fmt.Errorf("invalid link URL: %q", link) } else if p.Scheme != "https" { return fmt.Errorf("invalid link URL: %q; want scheme: %q", link, "https") } } return nil } func countNonEmpty(strings ...string) int { count := 0 for _, s := range strings { if s != "" { count++ } } return count } golang-google-firebase-go-4.18.0/messaging/topic_mgt.go000066400000000000000000000107621505612111400230010ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messaging import ( "context" "encoding/json" "fmt" "net/http" "strings" "firebase.google.com/go/v4/internal" ) const ( iidEndpoint = "https://iid.googleapis.com/iid/v1" iidSubscribe = "batchAdd" iidUnsubscribe = "batchRemove" ) // TopicManagementResponse is the result produced by topic management operations. // // TopicManagementResponse provides an overview of how many input tokens were successfully handled, // and how many failed. In case of failures, the Errors list provides specific details concerning // each error. 
type TopicManagementResponse struct { SuccessCount int FailureCount int Errors []*ErrorInfo } func newTopicManagementResponse(resp *iidResponse) *TopicManagementResponse { tmr := &TopicManagementResponse{} for idx, res := range resp.Results { if len(res) == 0 { tmr.SuccessCount++ } else { tmr.FailureCount++ reason := res["error"].(string) tmr.Errors = append(tmr.Errors, &ErrorInfo{ Index: idx, Reason: reason, }) } } return tmr } type iidClient struct { iidEndpoint string httpClient *internal.HTTPClient } func newIIDClient(hc *http.Client, conf *internal.MessagingConfig) *iidClient { client := internal.WithDefaultRetryConfig(hc) client.CreateErrFn = handleIIDError client.Opts = []internal.HTTPOption{ internal.WithHeader("access_token_auth", "true"), internal.WithHeader("x-goog-api-client", internal.GetMetricsHeader(conf.Version)), } return &iidClient{ iidEndpoint: iidEndpoint, httpClient: client, } } // SubscribeToTopic subscribes a list of registration tokens to a topic. // // The tokens list must not be empty, and have at most 1000 tokens. func (c *iidClient) SubscribeToTopic(ctx context.Context, tokens []string, topic string) (*TopicManagementResponse, error) { req := &iidRequest{ Topic: topic, Tokens: tokens, op: iidSubscribe, } return c.makeTopicManagementRequest(ctx, req) } // UnsubscribeFromTopic unsubscribes a list of registration tokens from a topic. // // The tokens list must not be empty, and have at most 1000 tokens. 
func (c *iidClient) UnsubscribeFromTopic(ctx context.Context, tokens []string, topic string) (*TopicManagementResponse, error) { req := &iidRequest{ Topic: topic, Tokens: tokens, op: iidUnsubscribe, } return c.makeTopicManagementRequest(ctx, req) } type iidRequest struct { Topic string `json:"to"` Tokens []string `json:"registration_tokens"` op string } type iidResponse struct { Results []map[string]interface{} `json:"results"` } type iidErrorResponse struct { Error string `json:"error"` } func (c *iidClient) makeTopicManagementRequest(ctx context.Context, req *iidRequest) (*TopicManagementResponse, error) { if len(req.Tokens) == 0 { return nil, fmt.Errorf("no tokens specified") } if len(req.Tokens) > 1000 { return nil, fmt.Errorf("tokens list must not contain more than 1000 items") } for _, token := range req.Tokens { if token == "" { return nil, fmt.Errorf("tokens list must not contain empty strings") } } if req.Topic == "" { return nil, fmt.Errorf("topic name not specified") } if !topicNamePattern.MatchString(req.Topic) { return nil, fmt.Errorf("invalid topic name: %q", req.Topic) } if !strings.HasPrefix(req.Topic, "/topics/") { req.Topic = "/topics/" + req.Topic } request := &internal.Request{ Method: http.MethodPost, URL: fmt.Sprintf("%s:%s", c.iidEndpoint, req.op), Body: internal.NewJSONEntity(req), } var result iidResponse if _, err := c.httpClient.DoAndUnmarshal(ctx, request, &result); err != nil { return nil, err } return newTopicManagementResponse(&result), nil } func handleIIDError(resp *internal.Response) error { base := internal.NewFirebaseError(resp) var ie iidErrorResponse json.Unmarshal(resp.Body, &ie) // ignore any json parse errors at this level if ie.Error != "" { base.String = fmt.Sprintf("error while calling the iid service: %s", ie.Error) } return base } golang-google-firebase-go-4.18.0/messaging/topic_mgt_test.go000066400000000000000000000160731505612111400240410ustar00rootroot00000000000000// Copyright 2019 Google Inc. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messaging import ( "context" "encoding/json" "io/ioutil" "net/http" "net/http/httptest" "reflect" "strings" "testing" "firebase.google.com/go/v4/errorutils" "firebase.google.com/go/v4/internal" ) func TestSubscribe(t *testing.T) { var tr *http.Request var b []byte ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr = r b, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") w.Write([]byte("{\"results\": [{}, {\"error\": \"error_reason\"}]}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.iidEndpoint = ts.URL + "/v1" resp, err := client.SubscribeToTopic(ctx, []string{"id1", "id2"}, "test-topic") if err != nil { t.Fatal(err) } checkIIDRequest(t, b, tr, iidSubscribe) checkTopicMgtResponse(t, resp) } func TestInvalidSubscribe(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } for _, tc := range invalidTopicMgtArgs { t.Run(tc.name, func(t *testing.T) { resp, err := client.SubscribeToTopic(ctx, tc.tokens, tc.topic) if err == nil || err.Error() != tc.want { t.Errorf( "SubscribeToTopic(%s) = (%#v, %v); want = (nil, %q)", tc.name, resp, err, tc.want) } }) } } func TestUnsubscribe(t *testing.T) { var tr *http.Request var b []byte ts := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { tr = r b, _ = ioutil.ReadAll(r.Body) w.Header().Set("Content-Type", "application/json") w.Write([]byte("{\"results\": [{}, {\"error\": \"error_reason\"}]}")) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.iidEndpoint = ts.URL + "/v1" resp, err := client.UnsubscribeFromTopic(ctx, []string{"id1", "id2"}, "test-topic") if err != nil { t.Fatal(err) } checkIIDRequest(t, b, tr, iidUnsubscribe) checkTopicMgtResponse(t, resp) } func TestInvalidUnsubscribe(t *testing.T) { ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } for _, tc := range invalidTopicMgtArgs { t.Run(tc.name, func(t *testing.T) { resp, err := client.UnsubscribeFromTopic(ctx, tc.tokens, tc.topic) if err == nil || err.Error() != tc.want { t.Errorf( "UnsubscribeFromTopic(%s) = (%#v, %v); want = (nil, %q)", tc.name, resp, err, tc.want) } }) } } func TestTopicManagementError(t *testing.T) { var resp string var status int ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(status) w.Header().Set("Content-Type", "application/json") w.Write([]byte(resp)) })) defer ts.Close() ctx := context.Background() client, err := NewClient(ctx, testMessagingConfig) if err != nil { t.Fatal(err) } client.iidEndpoint = ts.URL + "/v1" client.iidClient.httpClient.RetryConfig = nil cases := []struct { name, resp, want string status int check func(err error) bool }{ { name: "EmptyResponse", resp: "{}", want: "unexpected http response with status: 500\n{}", status: http.StatusInternalServerError, check: errorutils.IsInternal, }, { name: "ErrorCode", resp: "{\"error\": \"INVALID_ARGUMENT\"}", want: "error while calling the iid service: INVALID_ARGUMENT", status: http.StatusBadRequest, check: errorutils.IsInvalidArgument, }, { name: "NotJson", resp: "not json", want: "unexpected http response with 
status: 500\nnot json", status: http.StatusInternalServerError, check: errorutils.IsInternal, }, } for _, tc := range cases { resp = tc.resp status = tc.status tmr, err := client.SubscribeToTopic(ctx, []string{"id1"}, "topic") if err == nil || err.Error() != tc.want || !tc.check(err) { t.Errorf("SubscribeToTopic(%s) = (%#v, %v); want = (nil, %q)", tc.name, tmr, err, tc.want) } tmr, err = client.UnsubscribeFromTopic(ctx, []string{"id1"}, "topic") if err == nil || err.Error() != tc.want || !tc.check(err) { t.Errorf("UnsubscribeFromTopic(%s) = (%#v, %v); want = (nil, %q)", tc.name, tmr, err, tc.want) } } } func checkIIDRequest(t *testing.T, b []byte, tr *http.Request, op string) { var parsed map[string]interface{} if err := json.Unmarshal(b, &parsed); err != nil { t.Fatal(err) } want := map[string]interface{}{ "to": "/topics/test-topic", "registration_tokens": []interface{}{"id1", "id2"}, } if !reflect.DeepEqual(parsed, want) { t.Errorf("Body = %#v; want = %#v", parsed, want) } if tr.Method != http.MethodPost { t.Errorf("Method = %q; want = %q", tr.Method, http.MethodPost) } wantOp := "/v1:" + op if tr.URL.Path != wantOp { t.Errorf("Path = %q; want = %q", tr.URL.Path, wantOp) } if h := tr.Header.Get("Authorization"); h != "Bearer test-token" { t.Errorf("Authorization = %q; want = %q", h, "Bearer test-token") } xGoogAPIClientHeader := internal.GetMetricsHeader(testMessagingConfig.Version) if h := tr.Header.Get("x-goog-api-client"); h != xGoogAPIClientHeader { t.Errorf("x-goog-api-client header = %q; want = %q", h, xGoogAPIClientHeader) } } func checkTopicMgtResponse(t *testing.T, resp *TopicManagementResponse) { if resp.SuccessCount != 1 { t.Errorf("SuccessCount = %d; want = %d", resp.SuccessCount, 1) } if resp.FailureCount != 1 { t.Errorf("FailureCount = %d; want = %d", resp.FailureCount, 1) } if len(resp.Errors) != 1 { t.Fatalf("Errors = %d; want = %d", len(resp.Errors), 1) } e := resp.Errors[0] if e.Index != 1 { t.Errorf("ErrorInfo.Index = %d; want = %d", e.Index, 
1) } if e.Reason != "error_reason" { t.Errorf("ErrorInfo.Reason = %s; want = %s", e.Reason, "error_reason") } } var invalidTopicMgtArgs = []struct { name string tokens []string topic string want string }{ { name: "NoTokensAndTopic", want: "no tokens specified", }, { name: "NoTopic", tokens: []string{"token1"}, want: "topic name not specified", }, { name: "InvalidTopicName", tokens: []string{"token1"}, topic: "foo*bar", want: "invalid topic name: \"foo*bar\"", }, { name: "TooManyTokens", tokens: strings.Split("a"+strings.Repeat(",a", 1000), ","), topic: "topic", want: "tokens list must not contain more than 1000 items", }, { name: "EmptyToken", tokens: []string{"foo", ""}, topic: "topic", want: "tokens list must not contain empty strings", }, } golang-google-firebase-go-4.18.0/remoteconfig/000077500000000000000000000000001505612111400211635ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/remoteconfig/condition_evaluator.go000066400000000000000000000322041505612111400255630ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package remoteconfig import ( "crypto/sha256" "encoding/json" "errors" "fmt" "log" "math/big" "regexp" "strconv" "strings" ) type conditionEvaluator struct { evaluationContext map[string]any conditions []namedCondition } const ( maxConditionRecursionDepth = 10 rootNestingLevel = 0 doublePrecision = 64 whiteSpace = " " segmentSeparator = "." 
maxPossibleSegments = 5 ) var ( errTooManySegments = errors.New("number of segments exceeds maximum allowed length") errNegativeSegment = errors.New("segment cannot be negative") errInvalidCustomSignal = errors.New("missing operator, key, or target values for custom signal condition") ) const ( randomizationID = "randomizationID" totalMicroPercentiles = 100_000_000 lessThanOrEqual = "LESS_OR_EQUAL" greaterThan = "GREATER_THAN" between = "BETWEEN" ) const ( stringContains = "STRING_CONTAINS" stringDoesNotContain = "STRING_DOES_NOT_CONTAIN" stringExactlyMatches = "STRING_EXACTLY_MATCHES" stringContainsRegex = "STRING_CONTAINS_REGEX" numericLessThan = "NUMERIC_LESS_THAN" numericLessThanEqual = "NUMERIC_LESS_EQUAL" numericEqual = "NUMERIC_EQUAL" numericNotEqual = "NUMERIC_NOT_EQUAL" numericGreaterThan = "NUMERIC_GREATER_THAN" numericGreaterEqual = "NUMERIC_GREATER_EQUAL" semanticVersionLessThan = "SEMANTIC_VERSION_LESS_THAN" semanticVersionLessEqual = "SEMANTIC_VERSION_LESS_EQUAL" semanticVersionEqual = "SEMANTIC_VERSION_EQUAL" semanticVersionNotEqual = "SEMANTIC_VERSION_NOT_EQUAL" semanticVersionGreaterThan = "SEMANTIC_VERSION_GREATER_THAN" semanticVersionGreaterEqual = "SEMANTIC_VERSION_GREATER_EQUAL" ) func (ce *conditionEvaluator) evaluateConditions() map[string]bool { evaluatedConditions := make(map[string]bool) for _, condition := range ce.conditions { evaluatedConditions[condition.Name] = ce.evaluateCondition(condition.Condition, rootNestingLevel) } return evaluatedConditions } func (ce *conditionEvaluator) evaluateCondition(condition *oneOfCondition, nestingLevel int) bool { if nestingLevel >= maxConditionRecursionDepth { log.Println("Maximum recursion depth is exceeded.") return false } if condition.Boolean != nil { return *condition.Boolean } else if condition.OrCondition != nil { return ce.evaluateOrCondition(condition.OrCondition, nestingLevel+1) } else if condition.AndCondition != nil { return ce.evaluateAndCondition(condition.AndCondition, nestingLevel+1) 
} else if condition.Percent != nil { return ce.evaluatePercentCondition(condition.Percent) } else if condition.CustomSignal != nil { return ce.evaluateCustomSignalCondition(condition.CustomSignal) } log.Println("Unknown condition type encountered.") return false } func (ce *conditionEvaluator) evaluateOrCondition(orCondition *orCondition, nestingLevel int) bool { for _, condition := range orCondition.Conditions { result := ce.evaluateCondition(&condition, nestingLevel+1) if result { return true } } return false } func (ce *conditionEvaluator) evaluateAndCondition(andCondition *andCondition, nestingLevel int) bool { for _, condition := range andCondition.Conditions { result := ce.evaluateCondition(&condition, nestingLevel+1) if !result { return false } } return true } func (ce *conditionEvaluator) evaluatePercentCondition(percentCondition *percentCondition) bool { if rid, ok := ce.evaluationContext[randomizationID].(string); ok { if percentCondition.PercentOperator == "" { log.Println("Missing percent operator for percent condition.") return false } instanceMicroPercentile := computeInstanceMicroPercentile(percentCondition.Seed, rid) switch percentCondition.PercentOperator { case lessThanOrEqual: return instanceMicroPercentile <= percentCondition.MicroPercent case greaterThan: return instanceMicroPercentile > percentCondition.MicroPercent case between: return instanceMicroPercentile > percentCondition.MicroPercentRange.MicroPercentLowerBound && instanceMicroPercentile <= percentCondition.MicroPercentRange.MicroPercentUpperBound default: log.Printf("Unknown percent operator: %s\n", percentCondition.PercentOperator) return false } } log.Println("Missing or invalid randomizationID (requires a string value) for percent condition.") return false } func computeInstanceMicroPercentile(seed string, randomizationID string) uint32 { var sb strings.Builder if len(seed) > 0 { sb.WriteString(seed) sb.WriteRune('.') } sb.WriteString(randomizationID) stringToHash := sb.String() 
hash := sha256.New() hash.Write([]byte(stringToHash)) // Calculate the final SHA-256 hash as a byte slice (32 bytes). // Convert to a big.Int. The "0x" prefix is implicit in the conversion from hex to big.Int. hashBigInt := new(big.Int).SetBytes(hash.Sum(nil)) instanceMicroPercentileBigInt := new(big.Int).Mod(hashBigInt, big.NewInt(totalMicroPercentiles)) // Safely convert to uint32 since the range of instanceMicroPercentile is 0 to 100_000_000; range of uint32 is 0 to 4_294_967_295. return uint32(instanceMicroPercentileBigInt.Int64()) } func (ce *conditionEvaluator) evaluateCustomSignalCondition(customSignalCondition *customSignalCondition) bool { if err := customSignalCondition.isValid(); err != nil { log.Println(err) return false } actualValue, ok := ce.evaluationContext[customSignalCondition.CustomSignalKey] if !ok { log.Printf("Custom signal key: %s, missing from context\n", customSignalCondition.CustomSignalKey) return false } switch customSignalCondition.CustomSignalOperator { case stringContains: return compareStrings(customSignalCondition.TargetCustomSignalValues, actualValue, func(actualValue, target string) bool { return strings.Contains(actualValue, target) }) case stringDoesNotContain: return !compareStrings(customSignalCondition.TargetCustomSignalValues, actualValue, func(actualValue, target string) bool { return strings.Contains(actualValue, target) }) case stringExactlyMatches: return compareStrings(customSignalCondition.TargetCustomSignalValues, actualValue, func(actualValue, target string) bool { return strings.Trim(actualValue, whiteSpace) == strings.Trim(target, whiteSpace) }) case stringContainsRegex: return compareStrings(customSignalCondition.TargetCustomSignalValues, actualValue, func(actualValue, targetPattern string) bool { result, err := regexp.MatchString(targetPattern, actualValue) if err != nil { return false } return result }) // For numeric operators only one target value is allowed. 
case numericLessThan: return compareNumbers(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result < 0 }) case numericLessThanEqual: return compareNumbers(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result <= 0 }) case numericEqual: return compareNumbers(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result == 0 }) case numericNotEqual: return compareNumbers(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result != 0 }) case numericGreaterThan: return compareNumbers(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result > 0 }) case numericGreaterEqual: return compareNumbers(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result >= 0 }) // For semantic operators only one target value is allowed. case semanticVersionLessThan: return compareSemanticVersion(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result < 0 }) case semanticVersionLessEqual: return compareSemanticVersion(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result <= 0 }) case semanticVersionEqual: return compareSemanticVersion(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result == 0 }) case semanticVersionNotEqual: return compareSemanticVersion(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result != 0 }) case semanticVersionGreaterThan: return compareSemanticVersion(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return result > 0 }) case semanticVersionGreaterEqual: return compareSemanticVersion(customSignalCondition.TargetCustomSignalValues[0], actualValue, func(result int) bool { return 
result >= 0 }) } log.Printf("Unknown custom signal operator: %s\n", customSignalCondition.CustomSignalOperator) return false } func (cs *customSignalCondition) isValid() error { if cs.CustomSignalOperator == "" || cs.CustomSignalKey == "" || len(cs.TargetCustomSignalValues) == 0 { return errInvalidCustomSignal } return nil } func compareStrings(targetCustomSignalValues []string, actualValue any, predicateFn func(actualValue, target string) bool) bool { csValStr, ok := actualValue.(string) if !ok { if jsonBytes, err := json.Marshal(actualValue); err == nil { csValStr = string(jsonBytes) } else { log.Printf("Failed to parse custom signal value '%v' as a string : %v\n", actualValue, err) return false } } for _, target := range targetCustomSignalValues { if predicateFn(csValStr, target) { return true } } return false } func compareNumbers(targetCustomSignalValue string, actualValue any, predicateFn func(result int) bool) bool { targetFloat, err := strconv.ParseFloat(strings.Trim(targetCustomSignalValue, whiteSpace), doublePrecision) if err != nil { log.Printf("Failed to convert target custom signal value '%v' from string to number: %v", targetCustomSignalValue, err) return false } var actualValFloat float64 switch actualValue := actualValue.(type) { case float32: actualValFloat = float64(actualValue) case float64: actualValFloat = actualValue case int8: actualValFloat = float64(actualValue) case int: actualValFloat = float64(actualValue) case int16: actualValFloat = float64(actualValue) case int32: actualValFloat = float64(actualValue) case int64: actualValFloat = float64(actualValue) case uint8: actualValFloat = float64(actualValue) case uint: actualValFloat = float64(actualValue) case uint16: actualValFloat = float64(actualValue) case uint32: actualValFloat = float64(actualValue) case uint64: actualValFloat = float64(actualValue) case bool: if actualValue { actualValFloat = 1 } else { actualValFloat = 0 } case string: actualValFloat, err = 
strconv.ParseFloat(strings.Trim(actualValue, whiteSpace), doublePrecision) if err != nil { log.Printf("Failed to convert custom signal value '%v' from string to number: %v", actualValue, err) return false } default: log.Printf("Cannot parse custom signal value '%v' of type %T as a number", actualValue, actualValue) return false } result := 0 if actualValFloat > targetFloat { result = 1 } else if actualValFloat < targetFloat { result = -1 } return predicateFn(result) } func compareSemanticVersion(targetValue string, actualValue any, predicateFn func(result int) bool) bool { targetSemVer, err := transformVersionToSegments(strings.Trim(targetValue, whiteSpace)) if err != nil { log.Printf("Error transforming target semantic version %q: %v\n", targetValue, err) return false } actualValueStr := fmt.Sprintf("%v", actualValue) actualSemVer, err := transformVersionToSegments(strings.Trim(actualValueStr, whiteSpace)) if err != nil { log.Printf("Error transforming custom signal value '%v' to semantic version: %v\n", actualValue, err) return false } for idx := 0; idx < maxPossibleSegments; idx++ { if actualSemVer[idx] > targetSemVer[idx] { return predicateFn(1) } else if actualSemVer[idx] < targetSemVer[idx] { return predicateFn(-1) } } return predicateFn(0) } func transformVersionToSegments(version string) ([]int, error) { // Trim any trailing or leading segment separators (.) and split. trimmedVersion := strings.Trim(version, segmentSeparator) segments := strings.Split(trimmedVersion, segmentSeparator) if len(segments) > maxPossibleSegments { return nil, errTooManySegments } // Initialize with the maximum possible segment length for consistent comparison. 
transformedVersion := make([]int, maxPossibleSegments) for idx, segmentStr := range segments { segmentInt, err := strconv.Atoi(segmentStr) if err != nil { return nil, err } if segmentInt < 0 { return nil, errNegativeSegment } transformedVersion[idx] = segmentInt } return transformedVersion, nil } golang-google-firebase-go-4.18.0/remoteconfig/condition_evaluator_test.go000066400000000000000000001004731505612111400266260ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package remoteconfig import ( "encoding/json" "errors" "fmt" "reflect" "strconv" "strings" "testing" ) const ( isEnabled = "is_enabled" customSignalKey = "customSignalKey" premium = "premium" testRandomizationID = "123" testSeed = "abcdef" leadingWhiteSpaceCountTarget = 3 trailingWhiteSpaceCountTarget = 5 leadingWhiteSpaceCountActual = 4 trailingWhiteSpaceCountActual = 2 ) type customSignalTestCase struct { targets string actual any outcome bool } func createNamedCondition(name string, condition oneOfCondition) namedCondition { nc := namedCondition{ Name: name, Condition: &condition, } return nc } func evaluateConditionsAndReportResult(t *testing.T, nc namedCondition, conditionName string, context map[string]any, outcome bool) { ce := conditionEvaluator{ conditions: []namedCondition{nc}, evaluationContext: context, } ec := ce.evaluateConditions() value, ok := ec[conditionName] if !ok { t.Fatalf("condition %q was not found in evaluated conditions", conditionName) } if value != outcome { t.Errorf("condition evaluation for %q = %v, want = %v", conditionName, value, outcome) } } // Returns the number of assignments which evaluate to true for the specified percent condition. // This method randomly generates the ids for each assignment for this purpose. 
func evaluateRandomAssignments(numOfAssignments int, condition namedCondition) int { evalTrueCount := 0 for i := 0; i < numOfAssignments; i++ { context := map[string]any{randomizationID: fmt.Sprintf("random-%d", i)} ce := conditionEvaluator{ conditions: []namedCondition{condition}, evaluationContext: context, } ec := ce.evaluateConditions() if value, ok := ec[isEnabled]; ok && value { evalTrueCount++ } } return evalTrueCount } func runCustomSignalTestCase(operator string, t *testing.T) func(customSignalTestCase) { return func(tc customSignalTestCase) { description := fmt.Sprintf("Evaluates operator %v with targets %v and actual %v to outcome %v", operator, tc.targets, tc.actual, tc.outcome) t.Run(description, func(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ CustomSignal: &customSignalCondition{ CustomSignalOperator: operator, CustomSignalKey: customSignalKey, TargetCustomSignalValues: strings.Split(tc.targets, ","), }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{customSignalKey: tc.actual}, tc.outcome) }) } } func runCustomSignalTestCaseWithWhiteSpaces(operator string, t *testing.T) func(customSignalTestCase) { return func(tc customSignalTestCase) { targetsWithWhiteSpaces := []string{} for _, target := range strings.Split(tc.targets, ",") { targetsWithWhiteSpaces = append(targetsWithWhiteSpaces, addLeadingAndTrailingWhiteSpaces(target, leadingWhiteSpaceCountTarget, trailingWhiteSpaceCountTarget)) } runCustomSignalTestCase(operator, t)(customSignalTestCase{ outcome: tc.outcome, actual: addLeadingAndTrailingWhiteSpaces(tc.actual, leadingWhiteSpaceCountActual, trailingWhiteSpaceCountActual), targets: strings.Join(targetsWithWhiteSpaces, ","), }) } } func addLeadingAndTrailingWhiteSpaces(v any, leadingSpacesCount int, trailingSpacesCount int) string { vStr, ok := v.(string) if !ok { if jsonBytes, err := json.Marshal(v); err == nil { vStr = string(jsonBytes) } } return strings.Repeat(whiteSpace, 
leadingSpacesCount) + vStr + strings.Repeat(whiteSpace, trailingSpacesCount) } func TestEvaluateEmptyOrCondition(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ OrCondition: &orCondition{}, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, false) } func TestEvaluateEmptyOrAndCondition(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ OrCondition: &orCondition{ Conditions: []oneOfCondition{ { AndCondition: &andCondition{}, }, }, }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, true) } func TestEvaluateOrConditionShortCircuit(t *testing.T) { boolFalse := false boolTrue := true condition := createNamedCondition(isEnabled, oneOfCondition{ OrCondition: &orCondition{ Conditions: []oneOfCondition{ { Boolean: &boolFalse, }, { Boolean: &boolTrue, }, { Boolean: &boolFalse, }, }, }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, true) } func TestEvaluateAndConditionShortCircuit(t *testing.T) { boolFalse := false boolTrue := true condition := createNamedCondition(isEnabled, oneOfCondition{ AndCondition: &andCondition{ Conditions: []oneOfCondition{ { Boolean: &boolTrue, }, { Boolean: &boolFalse, }, { Boolean: &boolTrue, }, }, }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, false) } func TestPercentConditionWithoutRandomizationId(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: between, Seed: testSeed, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 0, MicroPercentUpperBound: 1_000_000, }, }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, false) } func TestUnknownPercentOperator(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: "UNKNOWN", Seed: testSeed, MicroPercentRange: microPercentRange{ 
MicroPercentLowerBound: 0, MicroPercentUpperBound: 1_000_000, }, }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, false) } func TestEmptyPercentOperator(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ Seed: testSeed, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 0, MicroPercentUpperBound: 1_000_000, }, }, }) evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{}, false) } func TestInvalidRandomizationIdType(t *testing.T) { // randomizationID is expected to be a string condition := createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ Seed: testSeed, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 0, MicroPercentUpperBound: 1_000_000, }, }, }) invalidRandomizationIDTestCases := []struct { randomizationID any }{ {randomizationID: 123}, {randomizationID: true}, {randomizationID: 123.4}, {randomizationID: "{\"hello\": \"world\"}"}, } for _, tc := range invalidRandomizationIDTestCases { description := fmt.Sprintf("RandomizationId %v of type %s", tc.randomizationID, reflect.TypeOf(tc.randomizationID)) t.Run(description, func(t *testing.T) { evaluateConditionsAndReportResult(t, condition, isEnabled, map[string]any{randomizationID: tc.randomizationID}, false) }) } } func TestInstanceMicroPercentileComputation(t *testing.T) { percentTestCases := []struct { seed string randomizationID string expectedMicroPercentile uint32 }{ {seed: "1", randomizationID: "one", expectedMicroPercentile: 64146488}, {seed: "2", randomizationID: "two", expectedMicroPercentile: 76516209}, {seed: "3", randomizationID: "three", expectedMicroPercentile: 6701947}, {seed: "4", randomizationID: "four", expectedMicroPercentile: 85000289}, {seed: "5", randomizationID: "five", expectedMicroPercentile: 2514745}, {seed: "", randomizationID: "😊", expectedMicroPercentile: 9911325}, {seed: "", randomizationID: "😀", expectedMicroPercentile: 62040281}, 
{seed: "hêl£o", randomizationID: "wørlÐ", expectedMicroPercentile: 67411682}, {seed: "řemøťe", randomizationID: "çōnfįġ", expectedMicroPercentile: 19728496}, {seed: "long", randomizationID: strings.Repeat(".", 100), expectedMicroPercentile: 39278120}, {seed: "very-long", randomizationID: strings.Repeat(".", 1000), expectedMicroPercentile: 71699042}, } for _, tc := range percentTestCases { description := fmt.Sprintf("Instance micro-percentile for seed %s & randomization_id %s", tc.seed, tc.randomizationID) t.Run(description, func(t *testing.T) { actualMicroPercentile := computeInstanceMicroPercentile(tc.seed, tc.randomizationID) if tc.expectedMicroPercentile != actualMicroPercentile { t.Errorf("instanceMicroPercentile = %d, want %d", actualMicroPercentile, tc.expectedMicroPercentile) } }) } } func TestPercentConditionMicroPercent(t *testing.T) { microPercentTestCases := []struct { description string operator string microPercent uint32 outcome bool }{ { description: "Evaluate LESS_OR_EQUAL to true when MicroPercent is max", operator: lessThanOrEqual, microPercent: 100_000_000, outcome: true, }, { description: "Evaluate LESS_OR_EQUAL to false when MicroPercent is min", operator: lessThanOrEqual, microPercent: 0, outcome: false, }, { description: "Evaluate LESS_OR_EQUAL to false when MicroPercent is not set (MicroPercent should use zero)", operator: lessThanOrEqual, outcome: false, }, { description: "Evaluate GREATER_THAN to true when MicroPercent is not set (MicroPercent should use zero)", operator: greaterThan, outcome: true, }, { description: "Evaluate GREATER_THAN max to false", operator: greaterThan, outcome: false, microPercent: 100_000_000, }, { description: "Evaluate LESS_OR_EQUAL to 9571542 to true", operator: lessThanOrEqual, microPercent: 9_571_542, // instanceMicroPercentile of abcdef.123 (testSeed.testRandomizationID) is 9_571_542 outcome: true, }, { description: "Evaluate greater than 9571542 to true", operator: greaterThan, microPercent: 9_571_541, // 
instanceMicroPercentile of abcdef.123 (testSeed.testRandomizationID) is 9_571_542 outcome: true, }, } for _, tc := range microPercentTestCases { t.Run(tc.description, func(t *testing.T) { percentCondition := createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: tc.operator, MicroPercent: tc.microPercent, Seed: testSeed, }, }) evaluateConditionsAndReportResult(t, percentCondition, isEnabled, map[string]any{"randomizationID": testRandomizationID}, tc.outcome) }) } } func TestPercentConditionMicroPercentRange(t *testing.T) { // These tests verify that the percentage-based conditions correctly target the intended proportion of users over many random evaluations. // The results are checked against expected statistical distributions to ensure accuracy within a defined tolerance (3 standard deviations). microPercentTestCases := []struct { description string operator string microPercentLb uint32 microPercentUb uint32 outcome bool }{ { description: "Evaluate to false when microPercentRange is not set", operator: between, outcome: false, }, { description: "Evaluate to false when upper bound is not set", microPercentLb: 0, operator: between, outcome: false, }, { description: "Evaluate to true when lower bound is not set and upper bound is max", microPercentUb: 100_000_000, operator: between, outcome: true, }, { description: "Evaluate to true when between lower and upper bound", // instanceMicroPercentile of abcdef.123 (testSeed.testRandomizationID) is 9_571_542 microPercentLb: 9_000_000, microPercentUb: 9_571_542, // interval is (9_000_000, 9_571_542] operator: between, outcome: true, }, { description: "Evaluate to false when lower and upper bounds are equal", microPercentLb: 98_000_000, microPercentUb: 98_000_000, operator: between, outcome: false, }, { description: "Evaluate to false when not between 9_400_000 and 9_500_000", // instanceMicroPercentile of abcdef.123 (testSeed.testRandomizationID) is 9_571_542 microPercentLb: 9_400_000, 
microPercentUb: 9_500_000, operator: between, outcome: false, }, } for _, tc := range microPercentTestCases { t.Run(tc.description, func(t *testing.T) { percentCondition := createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: tc.operator, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: tc.microPercentLb, MicroPercentUpperBound: tc.microPercentUb, }, Seed: testSeed, }, }) evaluateConditionsAndReportResult(t, percentCondition, isEnabled, map[string]any{randomizationID: testRandomizationID}, tc.outcome) }) } } // Statistically validates that percentage conditions accurately target the intended proportion of users over many random evaluations. func TestPercentConditionProbabilisticEvaluation(t *testing.T) { probabilisticEvalTestCases := []struct { description string condition namedCondition assignments int baseline int tolerance int }{ { description: "Evaluate less or equal to 10% to approx 10%", condition: createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: lessThanOrEqual, MicroPercent: 10_000_000, }, }), assignments: 100_000, baseline: 10000, tolerance: 284, // 284 is 3 standard deviations for 100k trials with 10% probability. }, { description: "Evaluate between 0 to 10% to approx 10%", condition: createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: between, MicroPercentRange: microPercentRange{ MicroPercentUpperBound: 10_000_000, }, }, }), assignments: 100_000, baseline: 10000, tolerance: 284, // 284 is 3 standard deviations for 100k trials with 10% probability. }, { description: "Evaluate greater than 10% to approx 90%", condition: createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: greaterThan, MicroPercent: 10_000_000, }, }), assignments: 100_000, baseline: 90000, tolerance: 284, // 284 is 3 standard deviations for 100k trials with 90% probability. 
}, { description: "Evaluate between 40% to 60% to approx 20%", condition: createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: between, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 40_000_000, MicroPercentUpperBound: 60_000_000, }, }, }), assignments: 100_000, baseline: 20000, tolerance: 379, // 379 is 3 standard deviations for 100k trials with 20% probability. }, { description: "Evaluate between interquartile range to approx 50%", condition: createNamedCondition(isEnabled, oneOfCondition{ Percent: &percentCondition{ PercentOperator: between, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 25_000_000, MicroPercentUpperBound: 75_000_000, }, }, }), assignments: 100_000, baseline: 50000, tolerance: 474, // 474 is 3 standard deviations for 100k trials with 50% probability. }, } for _, tc := range probabilisticEvalTestCases { t.Run(tc.description, func(t *testing.T) { truthyAssignments := evaluateRandomAssignments(tc.assignments, tc.condition) lessThan := truthyAssignments <= tc.baseline+tc.tolerance greaterThan := truthyAssignments >= tc.baseline-tc.tolerance outcome := lessThan && greaterThan if outcome != true { t.Errorf("Incorrect probabilistic evaluation: got %d true assignments, want between %d and %d (baseline %d, tolerance %d)", truthyAssignments, tc.baseline-tc.tolerance, tc.baseline+tc.tolerance, tc.baseline, tc.tolerance) } }) } } func TestCustomSignalConditionIsValid(t *testing.T) { testCases := []struct { description string condition customSignalCondition expected error }{ { description: "Valid condition", condition: customSignalCondition{ CustomSignalOperator: stringExactlyMatches, CustomSignalKey: customSignalKey, TargetCustomSignalValues: []string{premium}, }, expected: nil, }, { description: "Missing operator", condition: customSignalCondition{ CustomSignalKey: customSignalKey, TargetCustomSignalValues: []string{premium}, }, expected: errInvalidCustomSignal, }, { description: 
"Missing key", condition: customSignalCondition{ CustomSignalOperator: stringExactlyMatches, TargetCustomSignalValues: []string{premium}, }, expected: errInvalidCustomSignal, }, { description: "Missing target values", condition: customSignalCondition{ CustomSignalOperator: stringExactlyMatches, CustomSignalKey: customSignalKey, }, expected: errInvalidCustomSignal, }, { description: "Missing multiple fields (operator and key)", condition: customSignalCondition{ TargetCustomSignalValues: []string{premium}, }, expected: errInvalidCustomSignal, }, { description: "Missing all fields", condition: customSignalCondition{}, expected: errInvalidCustomSignal, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { actual := tc.condition.isValid() if actual != tc.expected { t.Errorf("isValid() = %v, want %v for condition: %+v", actual, tc.expected, tc.condition) } }) } } func TestEvaluateCustomSignalCondition_MissingKeyInContext(t *testing.T) { condition := createNamedCondition(isEnabled, oneOfCondition{ CustomSignal: &customSignalCondition{ CustomSignalOperator: stringExactlyMatches, CustomSignalKey: customSignalKey, TargetCustomSignalValues: []string{premium}, }, }) // Context does NOT contain 'customSignalKey' context := map[string]any{ "key": "value", } evaluateConditionsAndReportResult(t, condition, isEnabled, context, false) } func TestCustomSignals_StringContains(t *testing.T) { testCases := []customSignalTestCase{ {actual: "testing", targets: "test,sting", outcome: true}, {actual: "check for spaces", targets: "for ,test", outcome: true}, {actual: "no word is present", targets: "not,absent,words", outcome: false}, {actual: "case Sensitive", targets: "Case,sensitive", outcome: false}, {actual: "match 'single quote'", targets: "'single quote',Match", outcome: true}, {actual: false, targets: "true, false", outcome: false}, {actual: false, targets: "true,false", outcome: true}, {actual: "no quote present", targets: "'no quote',\"present\"", outcome: 
false}, {actual: 123, targets: "23,string", outcome: true}, {actual: 123.45, targets: "9862123451,23.4", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCase(stringContains, t)(tc) } } func TestCustomSignals_StringDoesNotContain(t *testing.T) { testCases := []customSignalTestCase{ {actual: "foobar", targets: "foo,biz", outcome: false}, {actual: "foobar", targets: "biz,cat,car", outcome: true}, {actual: 387.42, targets: "6.4,54", outcome: true}, {actual: "single quote present", targets: "'single quote',Present ", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCase(stringDoesNotContain, t)(tc) } } func TestCustomSignals_StringExactlyMatches(t *testing.T) { testCases := []customSignalTestCase{ {actual: "foobar", targets: "foo,biz", outcome: false}, {actual: "Foobar", targets: " Foobar ,cat,car", outcome: true}, {actual: "matches if there are leading and trailing whitespaces", targets: " matches if there are leading and trailing whitespaces ", outcome: true}, {actual: "does not match internal whitespaces", targets: " does not match internal whitespaces ", outcome: false}, {actual: 123.456, targets: "123.45,456", outcome: false}, {actual: 987654321.1234567, targets: " 987654321.1234567 ,12", outcome: true}, {actual: "single quote present", targets: "'single quote',Present ", outcome: false}, {actual: true, targets: "true ", outcome: true}, {actual: struct { index int category string }{index: 1, category: "sample"}, targets: "{index: 1, category: \"sample\"}", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(stringExactlyMatches, t)(tc) runCustomSignalTestCase(stringExactlyMatches, t)(tc) } } func TestCustomSignals_StringContainsRegex(t *testing.T) { testCases := []customSignalTestCase{ {actual: "foobar", targets: "^foo,biz", outcome: true}, // Matches start anchor ^foo {actual: " hello world ", targets: " hello , world ", outcome: false}, // Patterns are literal strings including spaces, 
neither matches exactly? (Outcome seems unexpected for contains) {actual: "endswithhello", targets: ".*hello$", outcome: true}, // Matches end anchor hello$ {actual: "foobar", targets: "^foo", outcome: true}, // Starts with "foo" {actual: "barfoo", targets: "^foo", outcome: false}, // Does not start with "foo" {actual: "foobar", targets: "bar$", outcome: true}, // Ends with "bar" {actual: "barfoo", targets: "bar$", outcome: false}, // Does not end with "bar" {actual: "hello world", targets: "hello.*world", outcome: true}, // Contains "hello" and "world" with anything in between {actual: "hello world", targets: "hello\\s+world", outcome: true}, // Contains "hello" and "world" with one or more whitespace in between {actual: "helloworld", targets: "hello\\s+world", outcome: false}, // No whitespace between hello and world {actual: "123-456-7890", targets: "\\d{3}-\\d{3}-\\d{4}", outcome: true}, // Phone number format {actual: "invalid", targets: "([a-z]+", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCase(stringContainsRegex, t)(tc) } } func TestCustomSignals_NumericLessThan(t *testing.T) { withWhiteSpaces := []customSignalTestCase{ {actual: int16(2), targets: "4", outcome: true}, {actual: " -2.0 ", targets: " -2 ", outcome: false}, {actual: uint8(25), targets: "25.6", outcome: true}, {actual: float32(-25.5), targets: "-25.6", outcome: false}, {actual: " -25.5", targets: " -25.1 ", outcome: true}, {actual: " 3", targets: " 2,4 ", outcome: false}, {actual: "0", targets: "0", outcome: false}, } for _, tc := range withWhiteSpaces { runCustomSignalTestCaseWithWhiteSpaces(numericLessThan, t)(tc) } withoutWhiteSpaces := append(withWhiteSpaces, customSignalTestCase{actual: false, targets: "1", outcome: true}) for _, tc := range withoutWhiteSpaces { runCustomSignalTestCase(numericLessThan, t)(tc) } } func TestCustomSignals_NumericLessEqual(t *testing.T) { testCases := []customSignalTestCase{ {actual: int16(2), targets: "4", outcome: true}, {actual: 
"-2", targets: "-2", outcome: true}, {actual: float32(25.5), targets: "25.6", outcome: true}, {actual: -25.5, targets: "-25.6", outcome: false}, {actual: "-25.5", targets: "-25.1", outcome: true}, {actual: "0", targets: "0", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(numericLessThanEqual, t)(tc) runCustomSignalTestCase(numericLessThanEqual, t)(tc) } } func TestCustomSignals_NumericEqual(t *testing.T) { testCases := []customSignalTestCase{ {actual: float32(2), targets: "4", outcome: false}, {actual: "-2", targets: "-2", outcome: true}, {actual: -25.5, targets: "-25.6", outcome: false}, {actual: "-25.5", targets: "123a", outcome: false}, {actual: uint16(0), targets: "0", outcome: true}, {actual: struct { index int }{index: 2}, targets: "0", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(numericEqual, t)(tc) runCustomSignalTestCase(numericEqual, t)(tc) } } func TestCustomSignals_NumericNotEqual(t *testing.T) { testCases := []customSignalTestCase{ {actual: int16(-2), targets: "4", outcome: true}, {actual: "-2", targets: "-2", outcome: false}, {actual: float32(-25.5), targets: "-25.6", outcome: true}, {actual: "123a", targets: "-25.5", outcome: false}, {actual: "0", targets: "0", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(numericNotEqual, t)(tc) runCustomSignalTestCase(numericNotEqual, t)(tc) } } func TestCustomSignals_NumericGreaterThan(t *testing.T) { testCases := []customSignalTestCase{ {actual: float32(2), targets: "4", outcome: false}, {actual: "-2", targets: "-2", outcome: false}, {actual: 25.59, targets: "25.6", outcome: false}, {actual: int32(-25), targets: "-25.6", outcome: true}, {actual: "-25.5", targets: "-25.5", outcome: false}, {actual: "0", targets: "0", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(numericGreaterThan, t)(tc) runCustomSignalTestCase(numericGreaterThan, t)(tc) } } 
func TestCustomSignals_NumericGreaterEqual(t *testing.T) { testCases := []customSignalTestCase{ {actual: uint32(2), targets: "4", outcome: false}, {actual: "-2", targets: "-2", outcome: true}, {actual: float32(25.5), targets: "25.6", outcome: false}, {actual: -25.5, targets: "-25.6", outcome: true}, {actual: "-25.5", targets: "-25.5", outcome: true}, {actual: "0", targets: "0", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(numericGreaterEqual, t)(tc) runCustomSignalTestCase(numericGreaterEqual, t)(tc) } } func Test_TransformVersionToSegments(t *testing.T) { versionToSegmentTestCases := []struct { description string semanticVersion string outcome struct { err error segments []int } }{ { semanticVersion: "1.2.3.4.5", description: "Valid semantic version with maximum allowed segments", outcome: struct { err error segments []int }{ segments: []int{1, 2, 3, 4, 5}, }, }, { semanticVersion: "1.2.3.4.5.6", description: "Returns error when version exceeds maximum allowed segments", outcome: struct { err error segments []int }{ err: errTooManySegments, segments: nil, }, }, { semanticVersion: "1.2.3.4.-5", description: "Returns error when a segment is negative", outcome: struct { err error segments []int }{ err: errNegativeSegment, segments: nil, }, }, { semanticVersion: ".1.2.", description: "Handles leading/trailing separators and pads missing segments with zero", outcome: struct { err error segments []int }{ segments: []int{1, 2, 0, 0, 0}, }, }, { semanticVersion: "abcd.123", description: "Returns error for non-numeric segment value", outcome: struct { err error segments []int }{ err: strconv.ErrSyntax, segments: nil, }, }, } for _, tc := range versionToSegmentTestCases { t.Run(tc.description, func(t *testing.T) { t.Helper() segments, err := transformVersionToSegments(tc.semanticVersion) if !errors.Is(err, tc.outcome.err) { t.Fatalf("transformVersionToSegments(%q) error = %v, want %v", tc.semanticVersion, err, tc.outcome.err) } if 
!reflect.DeepEqual(tc.outcome.segments, segments) { t.Errorf("transformVersionToSegments(%q) segments = %v, want %v", tc.semanticVersion, segments, tc.outcome.segments) } }) } } func TestCustomSignals_SemanticVersionLessThan(t *testing.T) { // a semantic version with leading or trailing segment separators cannot be entered on the console testCases := []customSignalTestCase{ {actual: uint16(2), targets: "4", outcome: true}, {actual: 2., targets: "4.0", outcome: true}, {actual: .9, targets: "0.4", outcome: false}, {actual: ".3", targets: "0.1", outcome: false}, {actual: float32(2.3), targets: "2.3.2", outcome: true}, {actual: "2.3.4.1", targets: "2.3.4", outcome: false}, {actual: 2.3, targets: "2.3.0", outcome: false}, {actual: int16(3), targets: "1.2,4", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(semanticVersionLessThan, t)(tc) runCustomSignalTestCase(semanticVersionLessThan, t)(tc) } } func TestCustomSignals_SemanticVersionLessEqual(t *testing.T) { // a semantic version with leading or trailing segment separators cannot be entered on the console testCases := []customSignalTestCase{ {actual: 2., targets: "2.0", outcome: true}, {actual: .456, targets: "0.456.13", outcome: true}, {actual: ".3", targets: "0.1,0.4", outcome: false}, {actual: float32(2.3), targets: "2.3.0", outcome: true}, {actual: "2.3.4.5.6", targets: "2.3.4.5.6", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(semanticVersionLessEqual, t)(tc) runCustomSignalTestCase(semanticVersionLessEqual, t)(tc) } } func TestCustomSignals_SemanticVersionEqual(t *testing.T) { // a semantic version with leading or trailing segment separators cannot be entered on the console testCases := []customSignalTestCase{ {actual: 2., targets: "2.0", outcome: true}, {actual: 2.0, targets: "2", outcome: true}, {actual: uint16(2), targets: "2", outcome: true}, {actual: ".3", targets: "0.1, 0.4", outcome: false}, {actual: "1.2.3.4.5.6", targets: 
"1.2.3", outcome: false}, {actual: float32(2.3), targets: "2.3.0", outcome: true}, {actual: "2.3.4.5.6", targets: "2.3.4.5.6", outcome: true}, {actual: "1.3.4.5.6", targets: "2.3.4.5.6", outcome: false}, {actual: "5.12.-3.4", targets: "5.12.3.4", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(semanticVersionEqual, t)(tc) runCustomSignalTestCase(semanticVersionEqual, t)(tc) } } func TestCustomSignals_SemanticVersionNotEqual(t *testing.T) { // a semantic version with leading or trailing segment separators cannot be entered on the console testCases := []customSignalTestCase{ {actual: 2.3, targets: "2.0", outcome: true}, {actual: uint32(8), targets: "2", outcome: true}, {actual: "1.2.3.4.5.6", targets: "1.2.3", outcome: false}, {actual: "2.3.4.5.6", targets: "2.3.4.5.6", outcome: false}, {actual: "5.12.-3.4", targets: "5.12.3.4", outcome: false}, {actual: "1.2.3", targets: "1.2.a", outcome: false}, {actual: struct{}{}, targets: "1", outcome: false}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(semanticVersionNotEqual, t)(tc) runCustomSignalTestCase(semanticVersionNotEqual, t)(tc) } } func TestCustomSignals_SemanticVersionGreaterThan(t *testing.T) { // a semantic version with leading or trailing segment separators cannot be entered on the console testCases := []customSignalTestCase{ {actual: 2., targets: "2.0", outcome: false}, {actual: 2.0, targets: "2", outcome: false}, {actual: ".3", targets: "0.1", outcome: true}, {actual: "1.2.3.4.5.6", targets: "1.2.3", outcome: false}, {actual: 12.4, targets: "12.3.0", outcome: true}, {actual: "2.3.4.5.6", targets: "2.3.4.5.6", outcome: false}, {actual: "5.12.3.4", targets: "5.11.8.9", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(semanticVersionGreaterThan, t)(tc) runCustomSignalTestCase(semanticVersionGreaterThan, t)(tc) } } func TestCustomSignals_SemanticVersionGreaterEqual(t *testing.T) { // a semantic version with 
leading or trailing segment separators cannot be entered on the console testCases := []customSignalTestCase{ {actual: 2., targets: "2.0", outcome: true}, {actual: int16(2), targets: "2", outcome: true}, {actual: ".3", targets: "0.1", outcome: true}, {actual: "1.2.3.4.5.6", targets: "1.2.3", outcome: false}, {actual: float32(12.4), targets: "12.3.0", outcome: true}, {actual: "2.3.4.5.6", targets: "2.3.4.5.6", outcome: true}, {actual: "5.12.3.4", targets: "5.11.8.9", outcome: true}, } for _, tc := range testCases { runCustomSignalTestCaseWithWhiteSpaces(semanticVersionGreaterEqual, t)(tc) runCustomSignalTestCase(semanticVersionGreaterEqual, t)(tc) } } golang-google-firebase-go-4.18.0/remoteconfig/remoteconfig.go000066400000000000000000000065601505612111400242020ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package remoteconfig provides functions to fetch and evaluate a server-side Remote Config template. package remoteconfig import ( "context" "encoding/json" "errors" "fmt" "firebase.google.com/go/v4/internal" ) const ( defaultBaseURL = "https://firebaseremoteconfig.googleapis.com" firebaseClientHeader = "X-Firebase-Client" ) // Client is the interface for the Remote Config Cloud service. type Client struct { *rcClient } // NewClient initializes a RemoteConfigClient with app-specific detail and returns a // client to be used by the user. 
func NewClient(ctx context.Context, c *internal.RemoteConfigClientConfig) (*Client, error) { if c.ProjectID == "" { return nil, errors.New("project ID is required to access Remote Conifg") } hc, _, err := internal.NewHTTPClient(ctx, c.Opts...) if err != nil { return nil, err } return &Client{ rcClient: newRcClient(hc, c), }, nil } // RemoteConfigClient facilitates requests to the Firebase Remote Config backend. type rcClient struct { httpClient *internal.HTTPClient project string rcBaseURL string version string } func newRcClient(client *internal.HTTPClient, conf *internal.RemoteConfigClientConfig) *rcClient { version := fmt.Sprintf("fire-admin-go/%s", conf.Version) client.Opts = []internal.HTTPOption{ internal.WithHeader(firebaseClientHeader, version), internal.WithHeader("X-Firebase-ETag", "true"), internal.WithHeader("x-goog-api-client", internal.GetMetricsHeader(conf.Version)), } // Handles errors for non-success HTTP status codes from Remote Config servers. client.CreateErrFn = handleRemoteConfigError return &rcClient{ rcBaseURL: defaultBaseURL, project: conf.ProjectID, version: version, httpClient: client, } } // GetServerTemplate initializes a new ServerTemplate instance and fetches the server template. func (c *rcClient) GetServerTemplate(ctx context.Context, defaultConfig map[string]any) (*ServerTemplate, error) { template, err := c.InitServerTemplate(defaultConfig, "") if err != nil { return nil, err } err = template.Load(ctx) return template, err } // InitServerTemplate initializes a new ServerTemplate with the default config and // an optional template data json. 
func (c *rcClient) InitServerTemplate(defaultConfig map[string]any, templateDataJSON string) (*ServerTemplate, error) { template, err := newServerTemplate(c, defaultConfig) if templateDataJSON != "" && err == nil { err = template.Set(templateDataJSON) } return template, err } func handleRemoteConfigError(resp *internal.Response) error { err := internal.NewFirebaseError(resp) var p struct { Error string `json:"error"` } json.Unmarshal(resp.Body, &p) if p.Error != "" { err.String = fmt.Sprintf("http error status: %d; reason: %s", resp.Status, p.Error) } return err } golang-google-firebase-go-4.18.0/remoteconfig/remoteconfig_test.go000066400000000000000000000030461505612111400252350ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package remoteconfig import ( "context" "testing" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" ) var ( client *Client testOpts = []option.ClientOption{ option.WithTokenSource(&internal.MockTokenSource{AccessToken: "mock-token"}), } ) // Test NewClient with valid config func TestNewClientSuccess(t *testing.T) { ctx := context.Background() config := &internal.RemoteConfigClientConfig{ ProjectID: "test-project", Opts: testOpts, Version: "1.2.3", } client, err := NewClient(ctx, config) if err != nil { t.Fatalf("NewClient failed: %v", err) } if client == nil { t.Error("NewClient returned nil client") } } // Test NewClient with missing Project ID func TestNewClientMissingProjectID(t *testing.T) { ctx := context.Background() config := &internal.RemoteConfigClientConfig{} _, err := NewClient(ctx, config) if err == nil { t.Fatal("NewClient should have failed with missing project ID") } } golang-google-firebase-go-4.18.0/remoteconfig/server_config.go000066400000000000000000000102641505612111400243500ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package remoteconfig import ( "slices" "strconv" "strings" ) // ValueSource represents the source of a value. type ValueSource int // Constants for value source. const ( sourceUnspecified ValueSource = iota Static // Static represents a statically defined value. Remote // Remote represents a value fetched from a remote source. 
Default // Default marks a value taken from the in-app default config supplied by the caller.
)

// value pairs a raw string config value with the source it was resolved from.
type value struct {
	source ValueSource
	value  string
}

// Default values for different parameter types.
// These are what the typed getters return when a key is missing or its value
// cannot be converted to the requested type.
const (
	DefaultValueForBoolean = false
	DefaultValueForString  = ""
	DefaultValueForNumber  = 0
)

// booleanTruthyValues lists the (lowercase) strings that GetBoolean treats as true.
var booleanTruthyValues = []string{"1", "true", "t", "yes", "y", "on"}

// ServerConfig holds the evaluated parameter values of a server template and
// exposes typed getters over them.
type ServerConfig struct {
	// configValues maps a parameter key to its resolved value and source.
	configValues map[string]value
}

// newServerConfig creates a new ServerConfig instance backed by the given map.
// The map is stored as-is (not copied).
func newServerConfig(configValues map[string]value) *ServerConfig {
	return &ServerConfig{configValues: configValues}
}

// GetBoolean returns the boolean value associated with the given key.
//
// It returns true if the string value is "1", "true", "t", "yes", "y", or "on" (case-insensitive).
// Otherwise, or if the key is not found, it returns the default boolean value (false).
func (s *ServerConfig) GetBoolean(key string) bool {
	return s.getValue(key).asBoolean()
}

// GetInt returns the integer value associated with the given key.
//
// If the parameter value cannot be parsed as an integer, or if the key is not found,
// it returns the default numeric value (0).
func (s *ServerConfig) GetInt(key string) int {
	return s.getValue(key).asInt()
}

// GetFloat returns the float value associated with the given key.
//
// If the parameter value cannot be parsed as a float64, or if the key is not found,
// it returns the default float value (0).
func (s *ServerConfig) GetFloat(key string) float64 {
	return s.getValue(key).asFloat()
}

// GetString returns the string value associated with the given key.
//
// If the key is not found, it returns the default string value ("").
func (s *ServerConfig) GetString(key string) string {
	return s.getValue(key).asString()
}

// GetValueSource returns the source of the value.
func (s *ServerConfig) GetValueSource(key string) ValueSource {
	resolved := s.getValue(key)
	return resolved.source
}

// getValue looks up key in the evaluated config.
// A missing key yields a Static placeholder carrying the default string value.
func (s *ServerConfig) getValue(key string) *value {
	entry, found := s.configValues[key]
	if !found {
		return newValue(Static, DefaultValueForString)
	}
	// entry is a copy of the map element, so callers cannot mutate the map through it.
	return &entry
}

// newValue builds a value with the given source; an empty customValue is
// replaced by the default string value.
func newValue(source ValueSource, customValue string) *value {
	v := &value{source: source, value: customValue}
	if v.value == "" {
		v.value = DefaultValueForString
	}
	return v
}

// asString returns the raw string held by the value.
func (v *value) asString() string {
	return v.value
}

// asBoolean reports whether the value is one of the truthy strings, ignoring case.
// Static values always yield the default boolean.
func (v *value) asBoolean() bool {
	if v.source != Static {
		lowered := strings.ToLower(v.value)
		return slices.Contains(booleanTruthyValues, lowered)
	}
	return DefaultValueForBoolean
}

// asInt parses the value as a base-10 int.
// Static values and unparsable strings yield the default number.
func (v *value) asInt() int {
	if v.source == Static {
		return DefaultValueForNumber
	}
	if parsed, err := strconv.Atoi(v.value); err == nil {
		return parsed
	}
	return DefaultValueForNumber
}

// asFloat returns the value as a float.
func (v *value) asFloat() float64 { if v.source == Static { return DefaultValueForNumber } num, err := strconv.ParseFloat(v.value, doublePrecision) if err != nil { return DefaultValueForNumber } return num } golang-google-firebase-go-4.18.0/remoteconfig/server_config_test.go000066400000000000000000000052051505612111400254060ustar00rootroot00000000000000package remoteconfig import "testing" type configGetterTestCase struct { name string key string expectedString string expectedInt int expectedBool bool expectedFloat float64 expectedSource ValueSource } func getTestConfig() ServerConfig { config := ServerConfig{ configValues: map[string]value{ paramOne: { value: valueOne, source: Default, }, paramTwo: { value: valueTwo, source: Remote, }, paramThree: { value: valueThree, source: Default, }, paramFour: { value: valueFour, source: Remote, }, }, } return config } func TestServerConfigGetters(t *testing.T) { config := getTestConfig() testCases := []configGetterTestCase{ { name: "Parameter Value : String, Default Source", key: paramOne, expectedString: valueOne, expectedInt: 0, expectedBool: false, expectedFloat: 0, expectedSource: Default, }, { name: "Parameter Value : JSON, Remote Source", key: paramTwo, expectedString: valueTwo, expectedInt: 0, expectedBool: false, expectedFloat: 0, expectedSource: Remote, }, { name: "Unknown Parameter Value", key: "unknown_param", expectedString: "", expectedInt: 0, expectedBool: false, expectedFloat: 0, expectedSource: Static, }, { name: "Parameter Value - Float, Default Source", key: paramThree, expectedString: "123456789.123", expectedInt: 0, expectedBool: false, expectedFloat: 123456789.123, expectedSource: Default, }, { name: "Parameter Value - Boolean, Remote Source", key: paramFour, expectedString: "1", expectedInt: 1, expectedBool: true, expectedFloat: 1, expectedSource: Remote, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { if got := config.GetString(tc.key); got != tc.expectedString { 
t.Errorf("GetString(%q): got %q, want %q", tc.key, got, tc.expectedString) } if got := config.GetInt(tc.key); got != tc.expectedInt { t.Errorf("GetInt(%q): got %d, want %d", tc.key, got, tc.expectedInt) } if got := config.GetBoolean(tc.key); got != tc.expectedBool { t.Errorf("GetBoolean(%q): got %t, want %t", tc.key, got, tc.expectedBool) } if got := config.GetFloat(tc.key); got != tc.expectedFloat { t.Errorf("GetFloat(%q): got %f, want %f", tc.key, got, tc.expectedFloat) } if got := config.GetValueSource(tc.key); got != tc.expectedSource { t.Errorf("GetValueSource(%q): got %v, want %v", tc.key, got, tc.expectedSource) } }) } } golang-google-firebase-go-4.18.0/remoteconfig/server_template.go000066400000000000000000000142501505612111400247150ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package remoteconfig import ( "context" "encoding/json" "errors" "fmt" "log" "net/http" "sync/atomic" "firebase.google.com/go/v4/internal" ) // serverTemplateData stores the internal representation of the server template. type serverTemplateData struct { // A list of conditions in descending order by priority. Parameters map[string]parameter `json:"parameters,omitempty"` // Map of parameter keys to their optional default values and optional conditional values. Conditions []namedCondition `json:"conditions,omitempty"` // Version information for the current Remote Config template. 
Version *version `json:"version,omitempty"` // Current Remote Config template ETag. ETag string `json:"etag"` } // ServerTemplate represents a template with configuration data, cache, and service information. type ServerTemplate struct { rcClient *rcClient cache atomic.Pointer[serverTemplateData] stringifiedDefaultConfig map[string]string } // newServerTemplate initializes a new ServerTemplate with optional default configuration. func newServerTemplate(rcClient *rcClient, defaultConfig map[string]any) (*ServerTemplate, error) { stringifiedConfig := make(map[string]string, len(defaultConfig)) // Pre-allocate map for key, value := range defaultConfig { if value == nil { stringifiedConfig[key] = "" continue } if stringVal, ok := value.(string); ok { stringifiedConfig[key] = stringVal continue } // Marshal the value to JSON bytes. jsonBytes, err := json.Marshal(value) if err != nil { return nil, fmt.Errorf("unable to stringify default value for parameter '%s': %w", key, err) } stringifiedConfig[key] = string(jsonBytes) } return &ServerTemplate{ rcClient: rcClient, stringifiedDefaultConfig: stringifiedConfig, }, nil } // Load fetches the server template data from the remote config service and caches it. func (s *ServerTemplate) Load(ctx context.Context) error { request := &internal.Request{ Method: http.MethodGet, URL: fmt.Sprintf("%s/v1/projects/%s/namespaces/firebase-server/serverRemoteConfig", s.rcClient.rcBaseURL, s.rcClient.project), } templateData := new(serverTemplateData) response, err := s.rcClient.httpClient.DoAndUnmarshal(ctx, request, &templateData) if err != nil { return err } templateData.ETag = response.Header.Get("etag") s.cache.Store(templateData) return nil } // Set initializes a template using a server template JSON. 
func (s *ServerTemplate) Set(templateDataJSON string) error { templateData := new(serverTemplateData) if err := json.Unmarshal([]byte(templateDataJSON), &templateData); err != nil { return fmt.Errorf("error while parsing server template: %v", err) } s.cache.Store(templateData) return nil } // ToJSON returns a json representing the cached serverTemplateData. func (s *ServerTemplate) ToJSON() (string, error) { jsonServerTemplate, err := json.Marshal(s.cache.Load()) if err != nil { return "", fmt.Errorf("error while parsing server template: %v", err) } return string(jsonServerTemplate), nil } // Evaluate and processes the cached template data. func (s *ServerTemplate) Evaluate(context map[string]any) (*ServerConfig, error) { if s.cache.Load() == nil { return &ServerConfig{}, errors.New("no Remote Config Server template in Cache, call Load() before calling Evaluate()") } config := make(map[string]value) // Initialize config with in-app default values. for key, inAppDefault := range s.stringifiedDefaultConfig { config[key] = value{source: Default, value: inAppDefault} } usedConditions := s.cache.Load().filterUsedConditions() ce := conditionEvaluator{ conditions: usedConditions, evaluationContext: context, } evaluatedConditions := ce.evaluateConditions() // Overlays config value objects derived by evaluating the template. for key, parameter := range s.cache.Load().Parameters { var paramValueWrapper parameterValue var matchedConditionName string // Iterate through used conditions in decreasing priority order. 
for _, condition := range usedConditions { if value, ok := parameter.ConditionalValues[condition.Name]; ok && evaluatedConditions[condition.Name] { paramValueWrapper = value matchedConditionName = condition.Name break } } if paramValueWrapper.UseInAppDefault != nil && *paramValueWrapper.UseInAppDefault { log.Printf("Parameter '%s': Condition '%s' uses in-app default.\n", key, matchedConditionName) } else if paramValueWrapper.Value != nil { config[key] = value{source: Remote, value: *paramValueWrapper.Value} } else if parameter.DefaultValue.UseInAppDefault != nil && *parameter.DefaultValue.UseInAppDefault { log.Printf("Parameter '%s': Using parameter's in-app default.\n", key) } else if parameter.DefaultValue.Value != nil { config[key] = value{source: Remote, value: *parameter.DefaultValue.Value} } } return newServerConfig(config), nil } // filterUsedConditions identifies conditions that are referenced by parameters and returns them in order of decreasing priority. func (s *serverTemplateData) filterUsedConditions() []namedCondition { usedConditionNames := make(map[string]struct{}) for _, parameter := range s.Parameters { for name := range parameter.ConditionalValues { usedConditionNames[name] = struct{}{} } } // Filter the original conditions list, preserving order. conditionsToEvaluate := make([]namedCondition, 0, len(usedConditionNames)) for _, condition := range s.Conditions { if _, ok := usedConditionNames[condition.Name]; ok { conditionsToEvaluate = append(conditionsToEvaluate, condition) } } return conditionsToEvaluate } golang-google-firebase-go-4.18.0/remoteconfig/server_template_test.go000066400000000000000000000325361505612111400257630ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package remoteconfig import ( "testing" ) const ( paramOne = "test_param_one" paramTwo = "test_param_two" paramThree = "test_param_three" paramFour = "test_param_four" paramFive = "test_param_five" valueOne = "test_value_one" valueTwo = "{\"test\" : \"value\"}" valueThree = "123456789.123" valueFour = "1" conditionOne = "test_condition_one" conditionTwo = "test_condition_two" customSignalKeyOne = "custom_signal_key_one" testEtag = "test-etag" testVersion = "test-version" ) // Test newServerTemplate with valid default config func TestNewServerTemplateStringifiesDefaults(t *testing.T) { defaultConfig := map[string]any{ paramOne: "value1", paramTwo: 123, paramThree: true, paramFour: nil, paramFive: "{\"test_param\" : \"test_value\"}", } expectedStringified := map[string]string{ paramOne: "value1", paramTwo: "123", paramThree: "true", paramFour: "", // nil becomes empty string paramFive: "{\"test_param\" : \"test_value\"}", } rcClient := &rcClient{} template, err := newServerTemplate(rcClient, defaultConfig) if err != nil { t.Fatalf("newServerTemplate() error = %v", err) } if template == nil { t.Fatal("newServerTemplate() returned nil template") } if len(template.stringifiedDefaultConfig) != len(defaultConfig) { t.Errorf("len(stringifiedDefaultConfig) = %d, want %d", len(template.stringifiedDefaultConfig), len(expectedStringified)) } for key, expectedValue := range expectedStringified { t.Run(key, func(t *testing.T) { actualValue, ok := template.stringifiedDefaultConfig[key] if !ok { t.Errorf("Key %q not found in stringifiedDefaultConfig", key) } else if 
actualValue != expectedValue { t.Errorf("stringifiedDefaultConfig[%q] = %q, want %q", key, actualValue, expectedValue) } }) } } // Test ServerTemplate.Set with valid JSON func TestServerTemplateSetSuccess(t *testing.T) { template := &ServerTemplate{} json := `{"conditions": [{"name": "percent_condition", "condition": {"orCondition": {"conditions": [{"andCondition": {"conditions": [{"percent": {"percentOperator": "BETWEEN", "seed": "fb4aczak670h", "microPercentRange": {"microPercentUpperBound": 34000000}}}]}}]}}}, {"name": "percent_2", "condition": {"orCondition": {"conditions": [{"andCondition": {"conditions": [{"percent": {"percentOperator": "BETWEEN", "seed": "yxmb9v8fafxg", "microPercentRange": {"microPercentLowerBound": 12000000, "microPercentUpperBound": 100000000}}}, {"customSignal": {"customSignalOperator": "STRING_CONTAINS", "customSignalKey": "test", "targetCustomSignalValues": ["hello"]}}]}}]}}}], "parameters": {"test": {"defaultValue": {"useInAppDefault": true}, "conditionalValues": {"percent_condition": {"value": "{\"condition\" : \"percent\"}"}}}}, "version": {"versionNumber": "266", "isLegacy": true}, "etag": "test_etag"}` err := template.Set(json) if err != nil { t.Fatalf("ServerTemplate.Set failed: %v", err) } if template.cache.Load() == nil { t.Fatal("ServerTemplate.Set did not store data in cache") } } // Test ServerTemplate.ToJSON with valid data func TestServerTemplateToJSONSuccess(t *testing.T) { template := &ServerTemplate{} value := "test_value_one" // The raw string value data := &serverTemplateData{ Parameters: map[string]parameter{ paramOne: { DefaultValue: parameterValue{ Value: &value, }, }, }, Version: &version{ VersionNumber: testVersion, IsLegacy: true, }, ETag: testEtag, } template.cache.Store(data) json, err := template.ToJSON() if err != nil { t.Fatalf("ServerTemplate.ToJSON failed: %v", err) } expectedJSON := 
`{"parameters":{"test_param_one":{"defaultValue":{"value":"test_value_one"}}},"version":{"versionNumber":"test-version","isLegacy":true},"etag":"test-etag"}` if json != expectedJSON { t.Fatalf("ServerTemplate.ToJSON returned incorrect json: %v want %v", json, expectedJSON) } } func TestServerTemplateReturnsDefaultFromRemote(t *testing.T) { paramVal := valueOne template := &ServerTemplate{} data := &serverTemplateData{ Parameters: map[string]parameter{ paramOne: { DefaultValue: parameterValue{ Value: ¶mVal, }, }, }, Version: &version{ VersionNumber: testVersion, }, ETag: testEtag, } template.cache.Store(data) context := make(map[string]any) config, err := template.Evaluate(context) if err != nil { t.Fatalf("Error in evaluating template %v", err) } if config == nil { t.Fatal("ServerTemplate.Evaluate returned nil config") } val := config.GetString(paramOne) src := config.GetValueSource(paramOne) if val != valueOne { t.Fatalf("ServerTemplate.Evaluate returned incorrect value: %v want %v", val, valueOne) } if src != Remote { t.Fatalf("ServerTemplate.Evaluate returned incorrect source: %v want %v", src, Remote) } } func TestEvaluateReturnsInAppDefault(t *testing.T) { booleanTrue := true td := &serverTemplateData{ Parameters: map[string]parameter{ paramOne: { DefaultValue: parameterValue{ UseInAppDefault: &booleanTrue, }, }, }, Version: &version{ VersionNumber: testVersion, }, ETag: testEtag, } testCases := []struct { name string stringifiedDefaultConfig map[string]string expectedValue string expectedSource ValueSource }{ { name: "No In-App Default Provided", stringifiedDefaultConfig: map[string]string{}, expectedValue: "", expectedSource: Static, }, { name: "In-App Default Provided", stringifiedDefaultConfig: map[string]string{paramOne: valueOne}, expectedValue: valueOne, expectedSource: Default, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { st := ServerTemplate{ stringifiedDefaultConfig: tc.stringifiedDefaultConfig, } st.cache.Store(td) config, 
err := st.Evaluate(map[string]any{}) if err != nil { t.Fatalf("Evaluate() error = %v", err) } if config == nil { t.Fatal("Evaluate() returned nil config") } val := config.GetString(paramOne) src := config.GetValueSource(paramOne) if val != tc.expectedValue { t.Errorf("GetString(%q) = %q, want %q", paramOne, val, tc.expectedValue) } if src != tc.expectedSource { t.Errorf("GetValueSource(%q) = %v, want %v", paramOne, src, tc.expectedSource) } }) } } func TestEvaluate_WithACondition_ReturnsConditionalRemoteValue(t *testing.T) { vOne := valueOne vTwo := valueTwo template := &ServerTemplate{} data := &serverTemplateData{ Parameters: map[string]parameter{ paramOne: { DefaultValue: parameterValue{ Value: &vOne, }, ConditionalValues: map[string]parameterValue{ conditionOne: { Value: &vTwo, }, }, }, }, Conditions: []namedCondition{ { Name: conditionOne, Condition: &oneOfCondition{ OrCondition: &orCondition{ Conditions: []oneOfCondition{ { Percent: &percentCondition{ PercentOperator: between, Seed: testSeed, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 0, MicroPercentUpperBound: totalMicroPercentiles, // upper bound is set to the max; the percent condition will always evaluate to true }, }, }, }, }, }, }, }, Version: &version{ VersionNumber: testVersion, }, ETag: testEtag, } template.cache.Store(data) context := map[string]any{randomizationID: testRandomizationID} config, err := template.Evaluate(context) if err != nil { t.Fatalf("Error in evaluating template %v", err) } if config == nil { t.Fatal("ServerTemplate.Evaluate returned nil config") } val := config.GetString(paramOne) src := config.GetValueSource(paramOne) if val != vTwo { t.Fatalf("ServerTemplate.Evaluate returned incorrect value: %v want %v", val, vTwo) } if src != Remote { t.Fatalf("ServerTemplate.Evaluate returned incorrect source: %v want %v", src, Remote) } } func TestEvaluate_WithACondition_ReturnsConditionalInAppDefaultValue(t *testing.T) { vOne := valueOne boolTrue := true template := 
&ServerTemplate{ stringifiedDefaultConfig: map[string]string{paramOne: valueThree}, } data := &serverTemplateData{ Parameters: map[string]parameter{ paramOne: { DefaultValue: parameterValue{ Value: &vOne, }, ConditionalValues: map[string]parameterValue{ conditionOne: { UseInAppDefault: &boolTrue, }, }, }, }, Conditions: []namedCondition{ { Name: conditionOne, Condition: &oneOfCondition{ OrCondition: &orCondition{ Conditions: []oneOfCondition{ { AndCondition: &andCondition{ Conditions: []oneOfCondition{ { Percent: &percentCondition{ PercentOperator: between, Seed: testSeed, MicroPercentRange: microPercentRange{ MicroPercentLowerBound: 0, MicroPercentUpperBound: totalMicroPercentiles, }, }, }, { CustomSignal: &customSignalCondition{ CustomSignalKey: customSignalKeyOne, CustomSignalOperator: stringExactlyMatches, TargetCustomSignalValues: []string{valueTwo}, }, }, }, }, }, }, }, }, }, }, Version: &version{ VersionNumber: testVersion, }, ETag: testEtag, } template.cache.Store(data) context := map[string]any{randomizationID: testRandomizationID, customSignalKeyOne: valueTwo} config, err := template.Evaluate(context) if err != nil { t.Fatalf("Error in evaluating template %v", err) } if config == nil { t.Fatal("ServerTemplate.Evaluate returned nil config") } val := config.GetString(paramOne) src := config.GetValueSource(paramOne) if val != valueThree { t.Fatalf("ServerTemplate.Evaluate returned incorrect value: %v want %v", val, valueThree) } if src != Default { t.Fatalf("ServerTemplate.Evaluate returned incorrect source: %v want %v", src, Default) } } func TestGetUsedConditions(t *testing.T) { ncOne := namedCondition{Name: "ncOne"} ncTwo := namedCondition{Name: "ncTwo"} ncThree := namedCondition{Name: "ncThree"} paramVal := valueOne testCases := []struct { name string data *serverTemplateData expectedConditions []namedCondition }{ { name: "No parameters, no conditions", data: &serverTemplateData{}, expectedConditions: []namedCondition{}, }, { name: "Parameters, but no 
conditions", data: &serverTemplateData{ Parameters: map[string]parameter{ paramOne: {DefaultValue: parameterValue{Value: ¶mVal}}, }, }, expectedConditions: []namedCondition{}, }, { name: "Conditions, but no parameters", data: &serverTemplateData{ Conditions: []namedCondition{ncOne, ncTwo}, }, expectedConditions: []namedCondition{}, }, { name: "Conditions, but parameters use no conditional values", data: &serverTemplateData{ Parameters: map[string]parameter{ paramOne: {DefaultValue: parameterValue{Value: ¶mVal}}, }, Conditions: []namedCondition{ncOne, ncTwo}, }, expectedConditions: []namedCondition{}, }, { name: "One parameter uses one condition", data: &serverTemplateData{ Parameters: map[string]parameter{ paramOne: {ConditionalValues: map[string]parameterValue{"ncOne": {Value: ¶mVal}}}, }, Conditions: []namedCondition{ncOne, ncTwo}, }, expectedConditions: []namedCondition{ncOne}, }, { name: "One parameter uses multiple conditions", data: &serverTemplateData{ Parameters: map[string]parameter{ paramOne: {ConditionalValues: map[string]parameterValue{ "ncOne": {Value: ¶mVal}, "ncThree": {Value: ¶mVal}, }}, }, Conditions: []namedCondition{ncOne, ncTwo, ncThree}, }, expectedConditions: []namedCondition{ncOne, ncThree}, }, { name: "Multiple parameters use overlapping conditions", data: &serverTemplateData{ Parameters: map[string]parameter{ paramOne: {ConditionalValues: map[string]parameterValue{"ncTwo": {Value: ¶mVal}}}, paramTwo: {ConditionalValues: map[string]parameterValue{"ncOne": {Value: ¶mVal}, "ncTwo": {Value: ¶mVal}}}, }, Conditions: []namedCondition{ncTwo, ncThree, ncOne}, }, expectedConditions: []namedCondition{ncTwo, ncOne}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { used := tc.data.filterUsedConditions() if len(used) != len(tc.expectedConditions) { t.Fatalf("filterUsedConditions() returned %d conditions, want %d", len(used), len(tc.expectedConditions)) } for idx, ec := range tc.expectedConditions { if used[idx].Name != ec.Name { 
t.Errorf("Condition at index %d has name %q, want %q", idx, used[idx].Name, ec.Name) } } }) } } golang-google-firebase-go-4.18.0/remoteconfig/server_template_types.go000066400000000000000000000163171505612111400261470ustar00rootroot00000000000000// Copyright 2025 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package remoteconfig // Represents a Remote Config condition in the dataplane. // A condition targets a specific group of users. A list of these conditions // comprises part of a Remote Config template. type namedCondition struct { // A non-empty and unique name of this condition. Name string `json:"name,omitempty"` // The logic of this condition. // See the documentation on https://firebase.google.com/docs/remote-config/condition-reference // for the expected syntax of this field. Condition *oneOfCondition `json:"condition,omitempty"` } // Represents a condition that may be one of several types. // Only the first defined field will be processed. type oneOfCondition struct { // Makes this condition an OR condition. OrCondition *orCondition `json:"orCondition,omitempty"` // Makes this condition an AND condition. AndCondition *andCondition `json:"andCondition,omitempty"` // Makes this condition a percent condition. Percent *percentCondition `json:"percent,omitempty"` // Makes this condition a custom signal condition. CustomSignal *customSignalCondition `json:"customSignal,omitempty"` // Added for the purpose of testing. 
Boolean *bool `json:"boolean,omitempty"`
}

// orCondition represents a collection of conditions that evaluate to true if any are true.
type orCondition struct {
	// The sub-conditions that are OR-ed together.
	Conditions []oneOfCondition `json:"conditions,omitempty"`
}

// andCondition represents a collection of conditions that evaluate to true if all are true.
type andCondition struct {
	// The sub-conditions that are AND-ed together.
	Conditions []oneOfCondition `json:"conditions,omitempty"`
}

// percentCondition represents a condition that compares the instance pseudo-random percentile to a given limit.
type percentCondition struct {
	// The choice of percent operator to determine how to compare targets to percent(s).
	PercentOperator string `json:"percentOperator,omitempty"`

	// The seed used when evaluating the hash function to map an instance to
	// a value in the hash space. This is a string which can have 0 - 32
	// characters and can contain ASCII characters [-_.0-9a-zA-Z]. The string is case-sensitive.
	Seed string `json:"seed,omitempty"`

	// The limit of percentiles to target in micro-percents when
	// using the LESS_OR_EQUAL and GREATER_THAN operators. The value must
	// be in the range [0, 100_000_000].
	MicroPercent uint32 `json:"microPercent,omitempty"`

	// The micro-percent interval to be used with the BETWEEN operator.
	// NOTE(review): encoding/json's omitempty does not omit struct values, so
	// this field is presumably always emitted when marshalling — confirm that is intended.
	MicroPercentRange microPercentRange `json:"microPercentRange,omitempty"`
}

// microPercentRange represents the interval of percentiles to target in micro-percents.
// Both bounds must be in the range [0, 100_000_000].
type microPercentRange struct {
	// The lower limit of percentiles to target in micro-percents.
	// The value must be in the range [0, 100_000_000].
	MicroPercentLowerBound uint32 `json:"microPercentLowerBound"`

	// The upper limit of percentiles to target in micro-percents.
	// The value must be in the range [0, 100_000_000].
	MicroPercentUpperBound uint32 `json:"microPercentUpperBound"`
}

// customSignalCondition represents a condition that compares provided signals against a target value.
type customSignalCondition struct {
	// The choice of custom signal operator to determine how to compare targets
	// to value(s).
	CustomSignalOperator string `json:"customSignalOperator,omitempty"`

	// The key of the signal set in the EvaluationContext.
	CustomSignalKey string `json:"customSignalKey,omitempty"`

	// A list of at most 100 target custom signal values. For numeric and semantic version operators, this will have exactly ONE target value.
	TargetCustomSignalValues []string `json:"targetCustomSignalValues,omitempty"`
}

// parameter is the structure representing a Remote Config parameter.
// At minimum, a `defaultValue` or a `conditionalValues` entry must be present for the parameter to have any effect.
type parameter struct {
	// The value to set the parameter to, when none of the named conditions evaluate to `true`.
	// NOTE(review): as a struct value this field is not dropped by omitempty;
	// an unset default presumably marshals as an empty object — confirm.
	DefaultValue parameterValue `json:"defaultValue,omitempty"`

	// A `(condition name, value)` map. The condition name of the highest priority
	// (the one listed first in the Remote Config template's conditions list) determines the value of this parameter.
	ConditionalValues map[string]parameterValue `json:"conditionalValues,omitempty"`

	// A description for this parameter. Should not be over 100 characters and may contain any Unicode characters.
	Description string `json:"description,omitempty"`

	// The data type for all values of this parameter in the current version of the template.
	// It can be a string, number, boolean or JSON, and defaults to type string if unspecified.
	ValueType string `json:"valueType,omitempty"`
}

// parameterValue represents a Remote Config parameter value
// that could be either an explicit parameter value or an in-app default value.
type parameterValue struct {
	// The `string` value that the parameter is set to when it is an explicit parameter value.
	// Nil when no explicit value is present.
	Value *string `json:"value,omitempty"`

	// If true, indicates that the in-app default value is to be used for the parameter.
	// Nil when the flag is absent.
	UseInAppDefault *bool `json:"useInAppDefault,omitempty"`
}

// version is the structure representing a Remote Config template version.
// Output only, except for the version description. Contains metadata about a particular
// version of the Remote Config template. All fields are set at the time the specified Remote Config template is published.
type version struct {
	// The version number of a Remote Config template.
	VersionNumber string `json:"versionNumber,omitempty"`

	// The timestamp of when this version of the Remote Config template was written to the
	// Remote Config backend.
	UpdateTime string `json:"updateTime,omitempty"`

	// The origin of the template update action.
	UpdateOrigin string `json:"updateOrigin,omitempty"`

	// The type of the template update action.
	UpdateType string `json:"updateType,omitempty"`

	// Aggregation of all metadata fields about the account that performed the update.
	UpdateUser *remoteConfigUser `json:"updateUser,omitempty"`

	// The user-provided description of the corresponding Remote Config template.
	Description string `json:"description,omitempty"`

	// The version number of the Remote Config template that has become the current version
	// due to a rollback. Only present if this version is the result of a rollback.
	RollbackSource string `json:"rollbackSource,omitempty"`

	// Indicates whether this Remote Config template was published before version history was supported.
	IsLegacy bool `json:"isLegacy,omitempty"`
}

// remoteConfigUser represents a Remote Config user.
type remoteConfigUser struct {
	// Email address. Output only.
	Email string `json:"email,omitempty"`

	// Display name. Output only.
	Name string `json:"name,omitempty"`

	// Image URL. Output only.
ImageURL string `json:"imageUrl,omitempty"` } golang-google-firebase-go-4.18.0/snippets/000077500000000000000000000000001505612111400203475ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/snippets/auth.go000066400000000000000000001333151505612111400216450ustar00rootroot00000000000000// Copyright 2017 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package snippets import ( "context" "encoding/base64" "encoding/json" "io/ioutil" "log" "net/http" "time" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/auth" "firebase.google.com/go/v4/auth/hash" "google.golang.org/api/iterator" ) // ================================================================== // https://firebase.google.com/docs/auth/admin/create-custom-tokens // ================================================================== func createCustomToken(ctx context.Context, app *firebase.App) string { // [START create_custom_token_golang] client, err := app.Auth(context.Background()) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } token, err := client.CustomToken(ctx, "some-uid") if err != nil { log.Fatalf("error minting custom token: %v\n", err) } log.Printf("Got custom token: %v\n", token) // [END create_custom_token_golang] return token } func createCustomTokenWithClaims(ctx context.Context, app *firebase.App) string { // [START create_custom_token_claims_golang] client, err := app.Auth(context.Background()) if err != nil { log.Fatalf("error getting Auth 
client: %v\n", err) } claims := map[string]interface{}{ "premiumAccount": true, } token, err := client.CustomTokenWithClaims(ctx, "some-uid", claims) if err != nil { log.Fatalf("error minting custom token: %v\n", err) } log.Printf("Got custom token: %v\n", token) // [END create_custom_token_claims_golang] return token } // ================================================================== // https://firebase.google.com/docs/auth/admin/verify-id-tokens // ================================================================== func verifyIDToken(ctx context.Context, app *firebase.App, idToken string) *auth.Token { // [START verify_id_token_golang] client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } token, err := client.VerifyIDToken(ctx, idToken) if err != nil { log.Fatalf("error verifying ID token: %v\n", err) } log.Printf("Verified ID token: %v\n", token) // [END verify_id_token_golang] return token } // ================================================================== // https://firebase.google.com/docs/auth/admin/manage-sessions // ================================================================== func revokeRefreshTokens(ctx context.Context, app *firebase.App, uid string) { // [START revoke_tokens_golang] client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } if err := client.RevokeRefreshTokens(ctx, uid); err != nil { log.Fatalf("error revoking tokens for user: %v, %v\n", uid, err) } // accessing the user's TokenValidAfter u, err := client.GetUser(ctx, uid) if err != nil { log.Fatalf("error getting user %s: %v\n", uid, err) } timestamp := u.TokensValidAfterMillis / 1000 log.Printf("the refresh tokens were revoked at: %d (UTC seconds) ", timestamp) // [END revoke_tokens_golang] } func verifyIDTokenAndCheckRevoked(ctx context.Context, app *firebase.App, idToken string) *auth.Token { // [START verify_id_token_and_check_revoked_golang] client, err := app.Auth(ctx) if err != nil { 
log.Fatalf("error getting Auth client: %v\n", err) } token, err := client.VerifyIDTokenAndCheckRevoked(ctx, idToken) if err != nil { if err.Error() == "ID token has been revoked" { // Token is revoked. Inform the user to reauthenticate or signOut() the user. } else { // Token is invalid } } log.Printf("Verified ID token: %v\n", token) // [END verify_id_token_and_check_revoked_golang] return token } // ================================================================== // https://firebase.google.com/docs/auth/admin/manage-users // ================================================================== func getUser(ctx context.Context, app *firebase.App) *auth.UserRecord { uid := "some_string_uid" // [START get_user_golang] // Get an auth client from the firebase.App client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } u, err := client.GetUser(ctx, uid) if err != nil { log.Fatalf("error getting user %s: %v\n", uid, err) } log.Printf("Successfully fetched user data: %v\n", u) // [END get_user_golang] return u } func getUserByEmail(ctx context.Context, client *auth.Client) *auth.UserRecord { email := "some@email.com" // [START get_user_by_email_golang] u, err := client.GetUserByEmail(ctx, email) if err != nil { log.Fatalf("error getting user by email %s: %v\n", email, err) } log.Printf("Successfully fetched user data: %v\n", u) // [END get_user_by_email_golang] return u } func getUserByPhone(ctx context.Context, client *auth.Client) *auth.UserRecord { phone := "+13214567890" // [START get_user_by_phone_golang] u, err := client.GetUserByPhoneNumber(ctx, phone) if err != nil { log.Fatalf("error getting user by phone %s: %v\n", phone, err) } log.Printf("Successfully fetched user data: %v\n", u) // [END get_user_by_phone_golang] return u } func bulkGetUsers(ctx context.Context, client *auth.Client) { // [START bulk_get_users_golang] getUsersResult, err := client.GetUsers(ctx, []auth.UserIdentifier{ auth.UIDIdentifier{UID: "uid1"}, 
auth.EmailIdentifier{Email: "user@example.com"}, auth.PhoneIdentifier{PhoneNumber: "+15555551234"}, auth.ProviderIdentifier{ProviderID: "google.com", ProviderUID: "google_uid1"}, }) if err != nil { log.Fatalf("error retriving multiple users: %v\n", err) } log.Printf("Successfully fetched user data:") for _, u := range getUsersResult.Users { log.Printf("%v", u) } log.Printf("Unable to find users corresponding to these identifiers:") for _, id := range getUsersResult.NotFound { log.Printf("%v", id) } // [END bulk_get_users_golang] } func createUser(ctx context.Context, client *auth.Client) *auth.UserRecord { // [START create_user_golang] params := (&auth.UserToCreate{}). Email("user@example.com"). EmailVerified(false). PhoneNumber("+15555550100"). Password("secretPassword"). DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). Disabled(false) u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } log.Printf("Successfully created user: %v\n", u) // [END create_user_golang] return u } func createUserWithUID(ctx context.Context, client *auth.Client) *auth.UserRecord { uid := "something" // [START create_user_with_uid_golang] params := (&auth.UserToCreate{}). UID(uid). Email("user@example.com"). PhoneNumber("+15555550100") u, err := client.CreateUser(ctx, params) if err != nil { log.Fatalf("error creating user: %v\n", err) } log.Printf("Successfully created user: %v\n", u) // [END create_user_with_uid_golang] return u } func updateUser(ctx context.Context, client *auth.Client) { uid := "d" // [START update_user_golang] params := (&auth.UserToUpdate{}). Email("user@example.com"). EmailVerified(true). PhoneNumber("+15555550100"). Password("newPassword"). DisplayName("John Doe"). PhotoURL("http://www.example.com/12345678/photo.png"). 
Disabled(true) u, err := client.UpdateUser(ctx, uid, params) if err != nil { log.Fatalf("error updating user: %v\n", err) } log.Printf("Successfully updated user: %v\n", u) // [END update_user_golang] } func deleteUser(ctx context.Context, client *auth.Client) { uid := "d" // [START delete_user_golang] err := client.DeleteUser(ctx, uid) if err != nil { log.Fatalf("error deleting user: %v\n", err) } log.Printf("Successfully deleted user: %s\n", uid) // [END delete_user_golang] } func bulkDeleteUsers(ctx context.Context, client *auth.Client) { // [START bulk_delete_users_golang] deleteUsersResult, err := client.DeleteUsers(ctx, []string{"uid1", "uid2", "uid3"}) if err != nil { log.Fatalf("error deleting users: %v\n", err) } log.Printf("Successfully deleted %d users", deleteUsersResult.SuccessCount) log.Printf("Failed to delete %d users", deleteUsersResult.FailureCount) for _, err := range deleteUsersResult.Errors { log.Printf("%v", err) } // [END bulk_delete_users_golang] } func customClaimsSet(ctx context.Context, app *firebase.App) { uid := "uid" // [START set_custom_user_claims_golang] // Get an auth client from the firebase.App client, err := app.Auth(ctx) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // Set admin privilege on the user corresponding to uid. claims := map[string]interface{}{"admin": true} err = client.SetCustomUserClaims(ctx, uid, claims) if err != nil { log.Fatalf("error setting custom claims %v\n", err) } // The new custom claims will propagate to the user's ID token the // next time a new one is issued. // [END set_custom_user_claims_golang] // erase all existing custom claims } func customClaimsVerify(ctx context.Context, client *auth.Client) { idToken := "token" // [START verify_custom_claims_golang] // Verify the ID token first. 
token, err := client.VerifyIDToken(ctx, idToken)
	if err != nil {
		log.Fatal(err)
	}

	claims := token.Claims
	// Comma-ok assertion: a missing or non-boolean "admin" claim is treated as
	// "not an admin" instead of panicking on the type assertion.
	if admin, ok := claims["admin"].(bool); ok && admin {
		// Allow access to requested admin resource.
	}
	// [END verify_custom_claims_golang]
}

// customClaimsRead fetches the user record for a fixed sample uid and logs the
// "admin" custom claim when it is present, boolean, and true.
func customClaimsRead(ctx context.Context, client *auth.Client) {
	uid := "uid"
	// [START read_custom_user_claims_golang]
	// Lookup the user associated with the specified uid.
	user, err := client.GetUser(ctx, uid)
	if err != nil {
		log.Fatal(err)
	}
	// The claims can be accessed on the user record.
	// The comma-ok assertion keeps a non-boolean "admin" claim from panicking.
	if admin, ok := user.CustomClaims["admin"].(bool); ok && admin {
		log.Println(admin)
	}
	// [END read_custom_user_claims_golang]
}

// customClaimsScript looks up a user by a fixed sample email address and, only
// if that email is verified, sets the "admin" custom claim to true.
func customClaimsScript(ctx context.Context, client *auth.Client) {
	// [START set_custom_user_claims_script_golang]
	user, err := client.GetUserByEmail(ctx, "user@admin.example.com")
	if err != nil {
		log.Fatal(err)
	}
	// Confirm user is verified
	if user.EmailVerified {
		// Add custom claims for additional privileges.
		// This will be picked up by the user on token refresh or next sign in on new device.
		err := client.SetCustomUserClaims(ctx, user.UID, map[string]interface{}{"admin": true})
		if err != nil {
			log.Fatalf("error setting custom claims %v\n", err)
		}

	}
	// [END set_custom_user_claims_script_golang]
}

func customClaimsIncremental(ctx context.Context, client *auth.Client) {
	// [START set_custom_user_claims_incremental_golang]
	user, err := client.GetUserByEmail(ctx, "user@admin.example.com")
	if err != nil {
		log.Fatal(err)
	}
	// Add incremental custom claim without overwriting existing claims.
	currentCustomClaims := user.CustomClaims
	if currentCustomClaims == nil {
		currentCustomClaims = map[string]interface{}{}
	}

	if _, found := currentCustomClaims["admin"]; found {
		// Add level.
		currentCustomClaims["accessLevel"] = 10
		// Add custom claims for additional privileges.
err := client.SetCustomUserClaims(ctx, user.UID, currentCustomClaims) if err != nil { log.Fatalf("error setting custom claims %v\n", err) } } // [END set_custom_user_claims_incremental_golang] } func listUsers(ctx context.Context, client *auth.Client) { // [START list_all_users_golang] // Note, behind the scenes, the Users() iterator will retrive 1000 Users at a time through the API iter := client.Users(ctx, "") for { user, err := iter.Next() if err == iterator.Done { break } if err != nil { log.Fatalf("error listing users: %s\n", err) } log.Printf("read user user: %v\n", user) } // Iterating by pages 100 users at a time. // Note that using both the Next() function on an iterator and the NextPage() // on a Pager wrapping that same iterator will result in an error. pager := iterator.NewPager(client.Users(ctx, ""), 100, "") for { var users []*auth.ExportedUserRecord nextPageToken, err := pager.NextPage(&users) if err != nil { log.Fatalf("paging error %v\n", err) } for _, u := range users { log.Printf("read user user: %v\n", u) } if nextPageToken == "" { break } } // [END list_all_users_golang] } func importUsers(ctx context.Context, app *firebase.App) { // [START build_user_list] // Up to 1000 users can be imported at once. var users []*auth.UserToImport users = append(users, (&auth.UserToImport{}). UID("uid1"). Email("user1@example.com"). PasswordHash([]byte("passwordHash1")). PasswordSalt([]byte("salt1"))) users = append(users, (&auth.UserToImport{}). UID("uid2"). Email("user2@example.com"). PasswordHash([]byte("passwordHash2")). 
PasswordSalt([]byte("salt2"))) // [END build_user_list] // [START import_users] client, err := app.Auth(ctx) if err != nil { log.Fatalln("Error initializing Auth client", err) } h := hash.HMACSHA256{ Key: []byte("secretKey"), } result, err := client.ImportUsers(ctx, users, auth.WithHash(h)) if err != nil { log.Fatalln("Unrecoverable error prevented the operation from running", err) } log.Printf("Successfully imported %d users\n", result.SuccessCount) log.Printf("Failed to import %d users\n", result.FailureCount) for _, e := range result.Errors { log.Printf("Failed to import user at index: %d due to error: %s\n", e.Index, e.Reason) } // [END import_users] } func importWithHMAC(ctx context.Context, client *auth.Client) { // [START import_with_hmac] users := []*auth.UserToImport{ (&auth.UserToImport{}). UID("some-uid"). Email("user@example.com"). PasswordHash([]byte("password-hash")). PasswordSalt([]byte("salt")), } h := hash.HMACSHA256{ Key: []byte("secret"), } result, err := client.ImportUsers(ctx, users, auth.WithHash(h)) if err != nil { log.Fatalln("Error importing users", err) } for _, e := range result.Errors { log.Println("Failed to import user", e.Reason) } // [END import_with_hmac] } func importWithPBKDF(ctx context.Context, client *auth.Client) { // [START import_with_pbkdf] users := []*auth.UserToImport{ (&auth.UserToImport{}). UID("some-uid"). Email("user@example.com"). PasswordHash([]byte("password-hash")). PasswordSalt([]byte("salt")), } h := hash.PBKDF2SHA256{ Rounds: 100000, } result, err := client.ImportUsers(ctx, users, auth.WithHash(h)) if err != nil { log.Fatalln("Error importing users", err) } for _, e := range result.Errors { log.Println("Failed to import user", e.Reason) } // [END import_with_pbkdf] } func importWithStandardScrypt(ctx context.Context, client *auth.Client) { // [START import_with_standard_scrypt] users := []*auth.UserToImport{ (&auth.UserToImport{}). UID("some-uid"). Email("user@example.com"). 
PasswordHash([]byte("password-hash")). PasswordSalt([]byte("salt")), } h := hash.StandardScrypt{ MemoryCost: 1024, Parallelization: 16, BlockSize: 8, DerivedKeyLength: 64, } result, err := client.ImportUsers(ctx, users, auth.WithHash(h)) if err != nil { log.Fatalln("Error importing users", err) } for _, e := range result.Errors { log.Println("Failed to import user", e.Reason) } // [END import_with_standard_scrypt] } func importWithBcrypt(ctx context.Context, client *auth.Client) { // [START import_with_bcrypt] users := []*auth.UserToImport{ (&auth.UserToImport{}). UID("some-uid"). Email("user@example.com"). PasswordHash([]byte("password-hash")). PasswordSalt([]byte("salt")), } h := hash.Bcrypt{} result, err := client.ImportUsers(ctx, users, auth.WithHash(h)) if err != nil { log.Fatalln("Error importing users", err) } for _, e := range result.Errors { log.Println("Failed to import user", e.Reason) } // [END import_with_bcrypt] } func importWithScrypt(ctx context.Context, client *auth.Client) { // [START import_with_scrypt] b64URLdecode := func(s string) []byte { b, err := base64.URLEncoding.DecodeString(s) if err != nil { log.Fatalln("Failed to decode string", err) } return b } b64Stddecode := func(s string) []byte { b, err := base64.StdEncoding.DecodeString(s) if err != nil { log.Fatalln("Failed to decode string", err) } return b } // Users retrieved from Firebase Auth's backend need to be base64URL decoded users := []*auth.UserToImport{ (&auth.UserToImport{}). UID("some-uid"). Email("user@example.com"). PasswordHash(b64URLdecode("password-hash")). PasswordSalt(b64URLdecode("salt")), } // All the parameters below can be obtained from the Firebase Console's "Users" // section. Base64 encoded parameters must be decoded into raw bytes. 
h := hash.Scrypt{ Key: b64Stddecode("base64-secret"), SaltSeparator: b64Stddecode("base64-salt-separator"), Rounds: 8, MemoryCost: 14, } result, err := client.ImportUsers(ctx, users, auth.WithHash(h)) if err != nil { log.Fatalln("Error importing users", err) } for _, e := range result.Errors { log.Println("Failed to import user", e.Reason) } // [END import_with_scrypt] } func importWithoutPassword(ctx context.Context, client *auth.Client) { // [START import_without_password] users := []*auth.UserToImport{ (&auth.UserToImport{}). UID("some-uid"). DisplayName("John Doe"). Email("johndoe@gmail.com"). PhotoURL("http://www.example.com/12345678/photo.png"). EmailVerified(true). PhoneNumber("+11234567890"). CustomClaims(map[string]interface{}{"admin": true}). // set this user as admin ProviderData([]*auth.UserProvider{ // user with Google provider { UID: "google-uid", Email: "johndoe@gmail.com", DisplayName: "John Doe", PhotoURL: "http://www.example.com/12345678/photo.png", ProviderID: "google.com", }, }), } result, err := client.ImportUsers(ctx, users) if err != nil { log.Fatalln("Error importing users", err) } for _, e := range result.Errors { log.Println("Failed to import user", e.Reason) } // [END import_without_password] } func loginHandler(client *auth.Client) http.HandlerFunc { // [START session_login] return func(w http.ResponseWriter, r *http.Request) { // Get the ID token sent by the client defer r.Body.Close() idToken, err := getIDTokenFromBody(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } // Set session expiration to 5 days. expiresIn := time.Hour * 24 * 5 // Create the session cookie. This will also verify the ID token in the process. // The session cookie will have the same claims as the ID token. // To only allow session cookie setting on recent sign-in, auth_time in ID token // can be checked to ensure user was recently signed in before creating a session cookie. 
cookie, err := client.SessionCookie(r.Context(), idToken, expiresIn) if err != nil { http.Error(w, "Failed to create a session cookie", http.StatusInternalServerError) return } // Set cookie policy for session cookie. http.SetCookie(w, &http.Cookie{ Name: "session", Value: cookie, MaxAge: int(expiresIn.Seconds()), HttpOnly: true, Secure: true, }) w.Write([]byte(`{"status": "success"}`)) } // [END session_login] } func loginWithAuthTimeCheckHandler(client *auth.Client) http.HandlerFunc { // [START check_auth_time] return func(w http.ResponseWriter, r *http.Request) { // Get the ID token sent by the client defer r.Body.Close() idToken, err := getIDTokenFromBody(r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } decoded, err := client.VerifyIDToken(r.Context(), idToken) if err != nil { http.Error(w, "Invalid ID token", http.StatusUnauthorized) return } // Return error if the sign-in is older than 5 minutes. if time.Now().Unix()-decoded.Claims["auth_time"].(int64) > 5*60 { http.Error(w, "Recent sign-in required", http.StatusUnauthorized) return } expiresIn := time.Hour * 24 * 5 cookie, err := client.SessionCookie(r.Context(), idToken, expiresIn) if err != nil { http.Error(w, "Failed to create a session cookie", http.StatusInternalServerError) return } http.SetCookie(w, &http.Cookie{ Name: "session", Value: cookie, MaxAge: int(expiresIn.Seconds()), HttpOnly: true, Secure: true, }) w.Write([]byte(`{"status": "success"}`)) } // [END check_auth_time] } func userProfileHandler(client *auth.Client) http.HandlerFunc { serveContentForUser := func(w http.ResponseWriter, r *http.Request, claims *auth.Token) { log.Println("Serving content") } // [START session_verify] return func(w http.ResponseWriter, r *http.Request) { // Get the ID token sent by the client cookie, err := r.Cookie("session") if err != nil { // Session cookie is unavailable. Force user to login. http.Redirect(w, r, "/login", http.StatusFound) return } // Verify the session cookie. 
In this case an additional check is added to detect // if the user's Firebase session was revoked, user deleted/disabled, etc. decoded, err := client.VerifySessionCookieAndCheckRevoked(r.Context(), cookie.Value) if err != nil { // Session cookie is invalid. Force user to login. http.Redirect(w, r, "/login", http.StatusFound) return } serveContentForUser(w, r, decoded) } // [END session_verify] } func adminUserHandler(client *auth.Client) http.HandlerFunc { serveContentForAdmin := func(w http.ResponseWriter, r *http.Request, claims *auth.Token) { log.Println("Serving content") } // [START session_verify_with_permission_check] return func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session") if err != nil { // Session cookie is unavailable. Force user to login. http.Redirect(w, r, "/login", http.StatusFound) return } decoded, err := client.VerifySessionCookieAndCheckRevoked(r.Context(), cookie.Value) if err != nil { // Session cookie is invalid. Force user to login. http.Redirect(w, r, "/login", http.StatusFound) return } // Check custom claims to confirm user is an admin. if decoded.Claims["admin"] != true { http.Error(w, "Insufficient permissions", http.StatusUnauthorized) return } serveContentForAdmin(w, r, decoded) } // [END session_verify_with_permission_check] } func sessionLogoutHandler() http.HandlerFunc { // [START session_clear] return func(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: "session", Value: "", MaxAge: 0, }) http.Redirect(w, r, "/login", http.StatusFound) } // [END session_clear] } func sessionLogoutHandlerWithRevocation(client *auth.Client) http.HandlerFunc { // [START session_clear_and_revoke] return func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session") if err != nil { // Session cookie is unavailable. Force user to login. 
http.Redirect(w, r, "/login", http.StatusFound) return } decoded, err := client.VerifySessionCookie(r.Context(), cookie.Value) if err != nil { // Session cookie is invalid. Force user to login. http.Redirect(w, r, "/login", http.StatusFound) return } if err := client.RevokeRefreshTokens(r.Context(), decoded.UID); err != nil { http.Error(w, "Failed to revoke refresh token", http.StatusInternalServerError) return } http.SetCookie(w, &http.Cookie{ Name: "session", Value: "", MaxAge: 0, }) http.Redirect(w, r, "/login", http.StatusFound) } // [END session_clear_and_revoke] } func getIDTokenFromBody(r *http.Request) (string, error) { b, err := ioutil.ReadAll(r.Body) if err != nil { return "", err } var parsedBody struct { IDToken string `json:"idToken"` } err = json.Unmarshal(b, &parsedBody) return parsedBody.IDToken, err } func newActionCodeSettings() *auth.ActionCodeSettings { // [START init_action_code_settings] actionCodeSettings := &auth.ActionCodeSettings{ URL: "https://www.example.com/checkout?cartId=1234", HandleCodeInApp: true, IOSBundleID: "com.example.ios", AndroidPackageName: "com.example.android", AndroidInstallApp: true, AndroidMinimumVersion: "12", DynamicLinkDomain: "coolapp.page.link", } // [END init_action_code_settings] return actionCodeSettings } func generatePasswordResetLink(ctx context.Context, client *auth.Client) { actionCodeSettings := newActionCodeSettings() displayName := "Example User" // [START password_reset_link] email := "user@example.com" link, err := client.PasswordResetLinkWithSettings(ctx, email, actionCodeSettings) if err != nil { log.Fatalf("error generating email link: %v\n", err) } // Construct password reset template, embed the link and send // using custom SMTP server. 
sendCustomEmail(email, displayName, link) // [END password_reset_link] } func generateEmailVerificationLink(ctx context.Context, client *auth.Client) { actionCodeSettings := newActionCodeSettings() displayName := "Example User" // [START email_verification_link] email := "user@example.com" link, err := client.EmailVerificationLinkWithSettings(ctx, email, actionCodeSettings) if err != nil { log.Fatalf("error generating email link: %v\n", err) } // Construct email verification template, embed the link and send // using custom SMTP server. sendCustomEmail(email, displayName, link) // [END email_verification_link] } func generateEmailSignInLink(ctx context.Context, client *auth.Client) { actionCodeSettings := newActionCodeSettings() displayName := "Example User" // [START sign_in_with_email_link] email := "user@example.com" link, err := client.EmailSignInLink(ctx, email, actionCodeSettings) if err != nil { log.Fatalf("error generating email link: %v\n", err) } // Construct sign-in with email link template, embed the link and send // using custom SMTP server. sendCustomEmail(email, displayName, link) // [END sign_in_with_email_link] } // Place holder function to make the compiler happy. This is referenced by all email action // link snippets. func sendCustomEmail(email, displayName, link string) {} // ===================================================================================== // https://cloud.google.com/identity-platform/docs/managing-providers-programmatically // ===================================================================================== func createSAMLProviderConfig(ctx context.Context, client *auth.Client) { // [START create_saml_provider] newConfig := (&auth.SAMLProviderConfigToCreate{}). DisplayName("SAML provider name"). Enabled(true). ID("saml.myProvider"). IDPEntityID("IDP_ENTITY_ID"). SSOURL("https://example.com/saml/sso/1234/"). 
X509Certificates([]string{ "-----BEGIN CERTIFICATE-----\nCERT1...\n-----END CERTIFICATE-----", "-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----", }). RPEntityID("RP_ENTITY_ID"). CallbackURL("https://project-id.firebaseapp.com/__/auth/handler") saml, err := client.CreateSAMLProviderConfig(ctx, newConfig) if err != nil { log.Fatalf("error creating SAML provider: %v\n", err) } log.Printf("Created new SAML provider: %s", saml.ID) // [END create_saml_provider] } func updateSAMLProviderConfig(ctx context.Context, client *auth.Client) { // [START update_saml_provider] updatedConfig := (&auth.SAMLProviderConfigToUpdate{}). X509Certificates([]string{ "-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----", "-----BEGIN CERTIFICATE-----\nCERT3...\n-----END CERTIFICATE-----", }) saml, err := client.UpdateSAMLProviderConfig(ctx, "saml.myProvider", updatedConfig) if err != nil { log.Fatalf("error updating SAML provider: %v\n", err) } log.Printf("Updated SAML provider: %s", saml.ID) // [END update_saml_provider] } func getSAMLProviderConfig(ctx context.Context, client *auth.Client) { // [START get_saml_provider] saml, err := client.SAMLProviderConfig(ctx, "saml.myProvider") if err != nil { log.Fatalf("error retrieving SAML provider: %v\n", err) } log.Printf("%s %t", saml.DisplayName, saml.Enabled) // [END get_saml_provider] } func deleteSAMLProviderConfig(ctx context.Context, client *auth.Client) { // [START delete_saml_provider] if err := client.DeleteSAMLProviderConfig(ctx, "saml.myProvider"); err != nil { log.Fatalf("error deleting SAML provider: %v\n", err) } // [END delete_saml_provider] } func listSAMLProviderConfigs(ctx context.Context, client *auth.Client) { // [START list_saml_providers] iter := client.SAMLProviderConfigs(ctx, "nextPageToken") for { saml, err := iter.Next() if err == iterator.Done { break } if err != nil { log.Fatalf("error retrieving SAML providers: %v\n", err) } log.Printf("%s\n", saml.ID) } // [END list_saml_providers] } 
// createOIDCProviderConfig builds an OIDCProviderConfigToCreate with
// placeholder values, submits it with CreateOIDCProviderConfig and logs the new
// provider ID. The process exits through log.Fatalf if creation fails.
func createOIDCProviderConfig(ctx context.Context, client *auth.Client) {
	// [START create_oidc_provider]
	newConfig := (&auth.OIDCProviderConfigToCreate{}).
		DisplayName("OIDC provider name").
		Enabled(true).
		ID("oidc.myProvider").
		ClientID("CLIENT_ID2").
		Issuer("https://oidc.com/CLIENT_ID2")
	oidc, err := client.CreateOIDCProviderConfig(ctx, newConfig)
	if err != nil {
		log.Fatalf("error creating OIDC provider: %v\n", err)
	}

	log.Printf("Created new OIDC provider: %s", oidc.ID)
	// [END create_oidc_provider]
}

// updateOIDCProviderConfig sends an update for the "oidc.myProvider" provider
// (display name, enabled flag, client ID and issuer) and logs the returned
// provider ID. The process exits through log.Fatalf if the update fails.
func updateOIDCProviderConfig(ctx context.Context, client *auth.Client) {
	// [START update_oidc_provider]
	updatedConfig := (&auth.OIDCProviderConfigToUpdate{}).
		DisplayName("OIDC provider name").
		Enabled(true).
		ClientID("CLIENT_ID").
		Issuer("https://oidc.com")
	oidc, err := client.UpdateOIDCProviderConfig(ctx, "oidc.myProvider", updatedConfig)
	if err != nil {
		log.Fatalf("error updating OIDC provider: %v\n", err)
	}

	log.Printf("Updated OIDC provider: %s", oidc.ID)
	// [END update_oidc_provider]
}

// getOIDCProviderConfig fetches the "oidc.myProvider" configuration and logs
// its display name and enabled flag. The process exits through log.Fatalf if
// the lookup fails.
func getOIDCProviderConfig(ctx context.Context, client *auth.Client) {
	// [START get_oidc_provider]
	oidc, err := client.OIDCProviderConfig(ctx, "oidc.myProvider")
	if err != nil {
		log.Fatalf("error retrieving OIDC provider: %v\n", err)
	}

	log.Printf("%s %t", oidc.DisplayName, oidc.Enabled)
	// [END get_oidc_provider]
}

// deleteOIDCProviderConfig deletes the "oidc.myProvider" configuration. The
// process exits through log.Fatalf if the deletion fails.
func deleteOIDCProviderConfig(ctx context.Context, client *auth.Client) {
	// [START delete_oidc_provider]
	if err := client.DeleteOIDCProviderConfig(ctx, "oidc.myProvider"); err != nil {
		log.Fatalf("error deleting OIDC provider: %v\n", err)
	}
	// [END delete_oidc_provider]
}

// listOIDCProviderConfigs iterates over OIDC provider configurations and logs
// each ID until the iterator reports iterator.Done.
// NOTE(review): the literal "nextPageToken" looks like a placeholder for a real
// page token -- confirm before reusing this snippet as-is.
func listOIDCProviderConfigs(ctx context.Context, client *auth.Client) {
	// [START list_oidc_providers]
	iter := client.OIDCProviderConfigs(ctx, "nextPageToken")
	for {
		oidc, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			log.Fatalf("error retrieving OIDC providers: %v\n", err)
		}

		log.Printf("%s\n", oidc.ID)
	}
	// [END list_oidc_providers]
}

//
================================================================================
// https://cloud.google.com/identity-platform/docs/multi-tenancy-managing-tenants
// =================================================================================

// getTenantClient obtains an auth client from app and returns the tenant-aware
// client that its TenantManager creates for tenantID. The process exits through
// log.Fatalf if either step fails.
func getTenantClient(ctx context.Context, app *firebase.App, tenantID string) *auth.TenantClient {
	// [START get_tenant_client]
	client, err := app.Auth(ctx)
	if err != nil {
		log.Fatalf("error initializing auth client: %v\n", err)
	}

	tm := client.TenantManager
	tenantClient, err := tm.AuthForTenant(tenantID)
	if err != nil {
		log.Fatalf("error initializing tenant-aware auth client: %v\n", err)
	}
	// [END get_tenant_client]

	return tenantClient
}

// getTenant looks up the tenant identified by tenantID and logs its ID. The
// process exits through log.Fatalf if the lookup fails.
func getTenant(ctx context.Context, client *auth.Client, tenantID string) {
	// [START get_tenant]
	tenant, err := client.TenantManager.Tenant(ctx, tenantID)
	if err != nil {
		log.Fatalf("error retrieving tenant: %v\n", err)
	}

	// Spelling matches the "Retrieved tenant" message logged by listTenants.
	log.Printf("Retrieved tenant: %q\n", tenant.ID)
	// [END get_tenant]
}

// createTenant creates a tenant named "myTenant1" with email-link sign-in and
// password sign-up both enabled, then logs the new tenant ID. The process exits
// through log.Fatalf if creation fails.
func createTenant(ctx context.Context, client *auth.Client) {
	// [START create_tenant]
	config := (&auth.TenantToCreate{}).
		DisplayName("myTenant1").
		EnableEmailLinkSignIn(true).
		AllowPasswordSignUp(true)
	tenant, err := client.TenantManager.CreateTenant(ctx, config)
	if err != nil {
		log.Fatalf("error creating tenant: %v\n", err)
	}

	log.Printf("Created tenant: %q\n", tenant.ID)
	// [END create_tenant]
}

// updateTenant renames the tenant identified by tenantID to "updatedName",
// turns password sign-up off, and logs the updated tenant ID. The process exits
// through log.Fatalf if the update fails.
func updateTenant(ctx context.Context, client *auth.Client, tenantID string) {
	// [START update_tenant]
	config := (&auth.TenantToUpdate{}).
		DisplayName("updatedName").
AllowPasswordSignUp(false) // Disable email provider
	tenant, err := client.TenantManager.UpdateTenant(ctx, tenantID, config)
	if err != nil {
		log.Fatalf("error updating tenant: %v\n", err)
	}

	log.Printf("Updated tenant: %q\n", tenant.ID)
	// [END update_tenant]
}

// deleteTenant deletes the tenant identified by tenantID. The process exits
// through log.Fatalf if the deletion fails; the message carries both the
// tenant ID and the underlying error.
func deleteTenant(ctx context.Context, client *auth.Client, tenantID string) {
	// [START delete_tenant]
	if err := client.TenantManager.DeleteTenant(ctx, tenantID); err != nil {
		// Report the error itself; previously only the tenant ID was printed
		// and the failure reason was lost.
		log.Fatalf("error deleting tenant %q: %v\n", tenantID, err)
	}
	// [END delete_tenant]
}

// listTenants iterates over all tenants, starting from an empty page token, and
// logs each tenant ID until the iterator reports iterator.Done. The process
// exits through log.Fatalf on any other iteration error.
func listTenants(ctx context.Context, client *auth.Client) {
	// [START list_tenants]
	iter := client.TenantManager.Tenants(ctx, "")
	for {
		tenant, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			log.Fatalf("error listing tenants: %v\n", err)
		}

		log.Printf("Retrieved tenant: %q\n", tenant.ID)
	}
	// [END list_tenants]
}

// createProviderTenant obtains a tenant-aware client for "TENANT-ID" and uses
// it to create a SAML provider configuration built from placeholder values.
// The process exits through log.Fatalf if either step fails.
func createProviderTenant(ctx context.Context, client *auth.Client) {
	// [START get_tenant_client_short]
	tenantClient, err := client.TenantManager.AuthForTenant("TENANT-ID")
	if err != nil {
		log.Fatalf("error initializing tenant client: %v\n", err)
	}
	// [END get_tenant_client_short]

	// [START create_saml_provider_tenant]
	newConfig := (&auth.SAMLProviderConfigToCreate{}).
		DisplayName("SAML provider name").
		Enabled(true).
		ID("saml.myProvider").
		IDPEntityID("IDP_ENTITY_ID").
		SSOURL("https://example.com/saml/sso/1234/").
		X509Certificates([]string{
			"-----BEGIN CERTIFICATE-----\nCERT1...\n-----END CERTIFICATE-----",
			"-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----",
		}).
		RPEntityID("RP_ENTITY_ID").
		// Using the default callback URL.
CallbackURL("https://project-id.firebaseapp.com/__/auth/handler")
	saml, err := tenantClient.CreateSAMLProviderConfig(ctx, newConfig)
	if err != nil {
		log.Fatalf("error creating SAML provider: %v\n", err)
	}

	log.Printf("Created new SAML provider: %s", saml.ID)
	// [END create_saml_provider_tenant]
}

// updateProviderTenant sends an update carrying two placeholder X.509
// certificates for the tenant's "saml.myProvider" provider and logs the
// returned provider ID. The process exits through log.Fatalf if it fails.
func updateProviderTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START update_saml_provider_tenant]
	updatedConfig := (&auth.SAMLProviderConfigToUpdate{}).
		X509Certificates([]string{
			"-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----",
			"-----BEGIN CERTIFICATE-----\nCERT3...\n-----END CERTIFICATE-----",
		})
	saml, err := tenantClient.UpdateSAMLProviderConfig(ctx, "saml.myProvider", updatedConfig)
	if err != nil {
		log.Fatalf("error updating SAML provider: %v\n", err)
	}

	log.Printf("Updated SAML provider: %s", saml.ID)
	// [END update_saml_provider_tenant]
}

// getProviderTenant fetches the tenant's "saml.myProvider" configuration and
// logs its display name and enabled flag. The process exits through log.Fatalf
// if the lookup fails.
func getProviderTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START get_saml_provider_tenant]
	saml, err := tenantClient.SAMLProviderConfig(ctx, "saml.myProvider")
	if err != nil {
		log.Fatalf("error retrieving SAML provider: %v\n", err)
	}

	// Get display name and whether it is enabled.
log.Printf("%s %t", saml.DisplayName, saml.Enabled)
	// [END get_saml_provider_tenant]
}

// listProviderConfigsTenant iterates over the tenant's SAML provider
// configurations and logs each ID until the iterator reports iterator.Done.
// NOTE(review): the literal "nextPageToken" looks like a placeholder for a real
// page token -- confirm before reusing this snippet as-is.
func listProviderConfigsTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START list_saml_providers_tenant]
	iter := tenantClient.SAMLProviderConfigs(ctx, "nextPageToken")
	for {
		saml, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			log.Fatalf("error retrieving SAML providers: %v\n", err)
		}

		log.Printf("%s\n", saml.ID)
	}
	// [END list_saml_providers_tenant]
}

// deleteProviderConfigTenant deletes the tenant's "saml.myProvider"
// configuration. The process exits through log.Fatalf if the deletion fails.
func deleteProviderConfigTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START delete_saml_provider_tenant]
	if err := tenantClient.DeleteSAMLProviderConfig(ctx, "saml.myProvider"); err != nil {
		log.Fatalf("error deleting SAML provider: %v\n", err)
	}
	// [END delete_saml_provider_tenant]
}

// getUserTenant fetches the user with a fixed sample UID through the
// tenant-aware client, logs the record and returns it. The process exits
// through log.Fatalf if the lookup fails.
func getUserTenant(ctx context.Context, tenantClient *auth.TenantClient) *auth.UserRecord {
	uid := "some_string_uid"

	// [START get_user_tenant]
	// Look up the user by UID with the tenant-aware auth client.
	u, err := tenantClient.GetUser(ctx, uid)
	if err != nil {
		log.Fatalf("error getting user %s: %v\n", uid, err)
	}
	log.Printf("Successfully fetched user data: %v\n", u)
	// [END get_user_tenant]
	return u
}

// getUserByEmailTenant fetches the user with a fixed sample email address
// through the tenant-aware client, logs the record and returns it. The process
// exits through log.Fatalf if the lookup fails.
func getUserByEmailTenant(ctx context.Context, tenantClient *auth.TenantClient) *auth.UserRecord {
	email := "some@email.com"
	// [START get_user_by_email_tenant]
	u, err := tenantClient.GetUserByEmail(ctx, email)
	if err != nil {
		log.Fatalf("error getting user by email %s: %v\n", email, err)
	}
	log.Printf("Successfully fetched user data: %v\n", u)
	// [END get_user_by_email_tenant]
	return u
}

// createUserTenant creates a user from sample profile values through the
// tenant-aware client, logs the new record and returns it. The process exits
// through log.Fatalf if creation fails.
func createUserTenant(ctx context.Context, tenantClient *auth.TenantClient) *auth.UserRecord {
	// [START create_user_tenant]
	params := (&auth.UserToCreate{}).
		Email("user@example.com").
		EmailVerified(false).
		PhoneNumber("+15555550100").
		Password("secretPassword").
		DisplayName("John Doe").
		PhotoURL("http://www.example.com/12345678/photo.png").
Disabled(false)

	u, err := tenantClient.CreateUser(ctx, params)
	if err != nil {
		log.Fatalf("error creating user: %v\n", err)
	}

	log.Printf("Successfully created user: %v\n", u)
	// [END create_user_tenant]
	return u
}

// updateUserTenant updates the user identified by uid with sample profile
// values (including Disabled(true)) through the tenant-aware client and logs
// the updated record. The process exits through log.Fatalf if the update fails.
func updateUserTenant(ctx context.Context, tenantClient *auth.TenantClient, uid string) {
	// [START update_user_tenant]
	params := (&auth.UserToUpdate{}).
		Email("user@example.com").
		EmailVerified(true).
		PhoneNumber("+15555550100").
		Password("newPassword").
		DisplayName("John Doe").
		PhotoURL("http://www.example.com/12345678/photo.png").
		Disabled(true)
	u, err := tenantClient.UpdateUser(ctx, uid, params)
	if err != nil {
		log.Fatalf("error updating user: %v\n", err)
	}

	log.Printf("Successfully updated user: %v\n", u)
	// [END update_user_tenant]
}

// deleteUserTenant deletes the user identified by uid through the tenant-aware
// client and logs the UID on success. The process exits through log.Fatalf if
// the deletion fails.
func deleteUserTenant(ctx context.Context, tenantClient *auth.TenantClient, uid string) {
	// [START delete_user_tenant]
	if err := tenantClient.DeleteUser(ctx, uid); err != nil {
		log.Fatalf("error deleting user: %v\n", err)
	}

	log.Printf("Successfully deleted user: %s\n", uid)
	// [END delete_user_tenant]
}

// listUsersTenant lists the tenant's users twice: first one record at a time
// with the Users() iterator, then in pages of 100 with an iterator.Pager. Each
// record is logged; any error other than iterator.Done exits through log.Fatalf.
func listUsersTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START list_all_users_tenant]
	// Note, behind the scenes, the Users() iterator will retrieve 1000 Users at a time through the API
	iter := tenantClient.Users(ctx, "")
	for {
		user, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			log.Fatalf("error listing users: %s\n", err)
		}
		log.Printf("read user user: %v\n", user)
	}

	// Iterating by pages 100 users at a time.
	// Note that using both the Next() function on an iterator and the NextPage()
	// on a Pager wrapping that same iterator will result in an error.
pager := iterator.NewPager(tenantClient.Users(ctx, ""), 100, "")
	for {
		var users []*auth.ExportedUserRecord
		nextPageToken, err := pager.NextPage(&users)
		if err != nil {
			log.Fatalf("paging error %v\n", err)
		}
		for _, u := range users {
			log.Printf("read user user: %v\n", u)
		}
		// An empty token means there are no further pages.
		if nextPageToken == "" {
			break
		}
	}
	// [END list_all_users_tenant]
}

// importWithHMACTenant imports two sample users that carry password hashes and
// salts, passing an HMACSHA256 hash configuration via auth.WithHash. The process
// exits through log.Fatalln if the import call fails; per-user failures in
// result.Errors are only logged.
func importWithHMACTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START import_with_hmac_tenant]
	users := []*auth.UserToImport{
		(&auth.UserToImport{}).
			UID("uid1").
			Email("user1@example.com").
			PasswordHash([]byte("password-hash-1")).
			PasswordSalt([]byte("salt1")),
		(&auth.UserToImport{}).
			UID("uid2").
			Email("user2@example.com").
			PasswordHash([]byte("password-hash-2")).
			PasswordSalt([]byte("salt2")),
	}
	h := hash.HMACSHA256{
		Key: []byte("secret"),
	}
	result, err := tenantClient.ImportUsers(ctx, users, auth.WithHash(h))
	if err != nil {
		log.Fatalln("Error importing users", err)
	}

	for _, e := range result.Errors {
		log.Println("Failed to import user", e.Reason)
	}
	// [END import_with_hmac_tenant]
}

// importWithoutPasswordTenant imports one sample user that has no password
// hash, an "admin" custom claim and a single SAML provider entry. The process
// exits through log.Fatalln if the import call fails; per-user failures in
// result.Errors are only logged.
func importWithoutPasswordTenant(ctx context.Context, tenantClient *auth.TenantClient) {
	// [START import_without_password_tenant]
	users := []*auth.UserToImport{
		(&auth.UserToImport{}).
			UID("some-uid").
			DisplayName("John Doe").
			Email("johndoe@acme.com").
			PhotoURL("https://www.example.com/12345678/photo.png").
			EmailVerified(true).
			PhoneNumber("+11234567890").
			// Set this user as admin.
			CustomClaims(map[string]interface{}{"admin": true}).
			// User with SAML provider.
ProviderData([]*auth.UserProvider{
				{
					UID:         "saml-uid",
					Email:       "johndoe@acme.com",
					DisplayName: "John Doe",
					PhotoURL:    "https://www.example.com/12345678/photo.png",
					ProviderID:  "saml.acme",
				},
			}),
	}
	result, err := tenantClient.ImportUsers(ctx, users)
	if err != nil {
		log.Fatalln("Error importing users", err)
	}
	for _, e := range result.Errors {
		log.Println("Failed to import user", e.Reason)
	}
	// [END import_without_password_tenant]
}

// verifyIDTokenTenant verifies idToken with the tenant-aware client and logs
// the tenant recorded in the token. The process exits through log.Fatalf if
// verification fails.
func verifyIDTokenTenant(ctx context.Context, tenantClient *auth.TenantClient, idToken string) {
	// [START verify_id_token_tenant]
	// idToken comes from the client app
	token, err := tenantClient.VerifyIDToken(ctx, idToken)
	if err != nil {
		log.Fatalf("error verifying ID token: %v\n", err)
	}

	// This should be set to TENANT-ID. Otherwise auth/mismatching-tenant-id error thrown.
	log.Printf("Verified ID token from tenant: %q\n", token.Firebase.Tenant)
	// [END verify_id_token_tenant]
}

// verifyIDTokenAccessControlTenant verifies idToken and branches on the tenant
// recorded in the token. The branches are intentionally empty placeholders for
// per-tenant access decisions. The process exits through log.Fatalf if
// verification fails.
func verifyIDTokenAccessControlTenant(ctx context.Context, tenantClient *auth.TenantClient, idToken string) {
	// [START id_token_access_control_tenant]
	token, err := tenantClient.VerifyIDToken(ctx, idToken)
	if err != nil {
		log.Fatalf("error verifying ID token: %v\n", err)
	}

	if token.Firebase.Tenant == "TENANT-ID1" {
		// Allow appropriate level of access for TENANT-ID1.
	} else if token.Firebase.Tenant == "TENANT-ID2" {
		// Allow appropriate level of access for TENANT-ID2.
	} else {
		// Access not allowed -- Handle error
	}
	// [END id_token_access_control_tenant]
}

// revokeRefreshTokensTenant revokes the refresh tokens of the user identified
// by uid, then re-reads the user and logs TokensValidAfterMillis converted to
// seconds. The process exits through log.Fatalf if either call fails.
func revokeRefreshTokensTenant(ctx context.Context, tenantClient *auth.TenantClient, uid string) {
	// [START revoke_tokens_tenant]
	// Revoke all refresh tokens for a specified user in a specified tenant for whatever reason.
	// Retrieve the timestamp of the revocation, in seconds since the epoch.
if err := tenantClient.RevokeRefreshTokens(ctx, uid); err != nil {
		log.Fatalf("error revoking tokens for user: %v, %v\n", uid, err)
	}

	// accessing the user's TokenValidAfter
	u, err := tenantClient.GetUser(ctx, uid)
	if err != nil {
		log.Fatalf("error getting user %s: %v\n", uid, err)
	}
	// TokensValidAfterMillis is in milliseconds; convert to seconds.
	timestamp := u.TokensValidAfterMillis / 1000
	log.Printf("the refresh tokens were revoked at: %d (UTC seconds) ", timestamp)
	// [END revoke_tokens_tenant]
}

// verifyIDTokenAndCheckRevokedTenant verifies idToken with the tenant-aware
// client, including the revocation check, and logs the token on success. On
// failure it classifies the error (revoked token, disabled user, anything
// else) in placeholder branches and returns without logging a success message.
func verifyIDTokenAndCheckRevokedTenant(ctx context.Context, tenantClient *auth.TenantClient, idToken string) {
	// [START verify_id_token_and_check_revoked_tenant]
	// Verify the ID token for a specific tenant while checking if the token is revoked.
	token, err := tenantClient.VerifyIDTokenAndCheckRevoked(ctx, idToken)
	if err != nil {
		if auth.IsIDTokenRevoked(err) {
			// Token is revoked. Inform the user to reauthenticate or signOut() the user.
		} else if auth.IsUserDisabled(err) {
			// The user account is disabled.
		} else {
			// Token is invalid
		}
		// Verification failed, so the token must not be reported as verified.
		return
	}
	log.Printf("Verified ID token: %v\n", token)
	// [END verify_id_token_and_check_revoked_tenant]
}

// customClaimsSetTenant sets the custom claim "admin": true on the user
// identified by uid. The process exits through log.Fatalf if the call fails.
func customClaimsSetTenant(ctx context.Context, tenantClient *auth.TenantClient, uid string) {
	// [START set_custom_user_claims_tenant]
	// Set admin privilege on the user corresponding to uid.
	claims := map[string]interface{}{"admin": true}
	if err := tenantClient.SetCustomUserClaims(ctx, uid, claims); err != nil {
		log.Fatalf("error setting custom claims %v\n", err)
	}
	// The new custom claims will propagate to the user's ID token the
	// next time a new one is issued.
	// [END set_custom_user_claims_tenant]
}

// customClaimsVerifyTenant verifies idToken and inspects its "admin" claim.
// The process exits through log.Fatal if verification fails.
// NOTE(review): admin.(bool) is an unchecked type assertion and panics if the
// claim is present but not a bool -- consider the two-value form.
func customClaimsVerifyTenant(ctx context.Context, tenantClient *auth.TenantClient, idToken string) {
	// [START verify_custom_claims_tenant]
	// Verify the ID token first.
	token, err := tenantClient.VerifyIDToken(ctx, idToken)
	if err != nil {
		log.Fatal(err)
	}

	claims := token.Claims
	if admin, ok := claims["admin"]; ok {
		if admin.(bool) {
			//Allow access to requested admin resource.
} } // [END verify_custom_claims_tenant] } func customClaimsReadTenant(ctx context.Context, tenantClient *auth.TenantClient, uid string) { // [START read_custom_user_claims_tenant] // Lookup the user associated with the specified uid. user, err := tenantClient.GetUser(ctx, uid) if err != nil { log.Fatal(err) } // The claims can be accessed on the user record. if admin, ok := user.CustomClaims["admin"]; ok { if admin.(bool) { log.Println(admin) } } // [END read_custom_user_claims_tenant] } func generateEmailVerificationLinkTenant(ctx context.Context, tenantClient *auth.TenantClient) { displayName := "Example User" email := "user@example.com" // [START email_verification_link_tenant] actionCodeSettings := &auth.ActionCodeSettings{ // URL you want to redirect back to. The domain (www.example.com) for // this URL must be whitelisted in the GCP Console. URL: "https://www.example.com/checkout?cartId=1234", // This must be true for email link sign-in. HandleCodeInApp: true, IOSBundleID: "com.example.ios", AndroidPackageName: "com.example.android", AndroidInstallApp: true, AndroidMinimumVersion: "12", // FDL custom domain. DynamicLinkDomain: "coolapp.page.link", } link, err := tenantClient.EmailVerificationLinkWithSettings(ctx, email, actionCodeSettings) if err != nil { log.Fatalf("error generating email link: %v\n", err) } // Construct email verification template, embed the link and send // using custom SMTP server. sendCustomEmail(email, displayName, link) // [END email_verification_link_tenant] } golang-google-firebase-go-4.18.0/snippets/db.go000066400000000000000000000340611505612111400212670ustar00rootroot00000000000000// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package snippets

// [START authenticate_db_imports]
import (
	"context"
	"fmt"
	"log"

	firebase "firebase.google.com/go/v4"
	"firebase.google.com/go/v4/db"
	"google.golang.org/api/option"
)

// [END authenticate_db_imports]

// authenticateWithAdminPrivileges initializes an app from a service account
// credentials file with only DatabaseURL configured, reads
// "restricted_access/secret_document" and prints the result. The process exits
// through log.Fatalln if any step fails.
func authenticateWithAdminPrivileges() {
	// [START authenticate_with_admin_privileges]
	ctx := context.Background()
	conf := &firebase.Config{
		DatabaseURL: "https://databaseName.firebaseio.com",
	}
	// Fetch the service account key JSON file contents
	opt := option.WithCredentialsFile("path/to/serviceAccountKey.json")

	// Initialize the app with a service account, granting admin privileges
	app, err := firebase.NewApp(ctx, conf, opt)
	if err != nil {
		log.Fatalln("Error initializing app:", err)
	}

	client, err := app.Database(ctx)
	if err != nil {
		log.Fatalln("Error initializing database client:", err)
	}

	// As an admin, the app has access to read and write all data, regardless of Security Rules
	ref := client.NewRef("restricted_access/secret_document")
	var data map[string]interface{}
	if err := ref.Get(ctx, &data); err != nil {
		log.Fatalln("Error reading from database:", err)
	}
	fmt.Println(data)
	// [END authenticate_with_admin_privileges]
}

// authenticateWithLimitedPrivileges initializes an app whose AuthOverride is
// set to {"uid": "my-service-worker"}, reads "/some_resource" and prints the
// result. The process exits through log.Fatalln if any step fails.
func authenticateWithLimitedPrivileges() {
	// [START authenticate_with_limited_privileges]
	ctx := context.Background()
	// Initialize the app with a custom auth variable, limiting the server's access
	ao := map[string]interface{}{"uid": "my-service-worker"}
	conf := &firebase.Config{
		DatabaseURL:  "https://databaseName.firebaseio.com",
		AuthOverride: &ao,
	}

	// Fetch the service account key JSON file contents
	opt :=
option.WithCredentialsFile("path/to/serviceAccountKey.json")

	app, err := firebase.NewApp(ctx, conf, opt)
	if err != nil {
		log.Fatalln("Error initializing app:", err)
	}

	client, err := app.Database(ctx)
	if err != nil {
		log.Fatalln("Error initializing database client:", err)
	}

	// The app only has access as defined in the Security Rules
	ref := client.NewRef("/some_resource")
	var data map[string]interface{}
	if err := ref.Get(ctx, &data); err != nil {
		log.Fatalln("Error reading from database:", err)
	}
	fmt.Println(data)
	// [END authenticate_with_limited_privileges]
}

// authenticateWithGuestPrivileges initializes an app whose AuthOverride points
// at a nil map, reads "/some_resource" and prints the result. The process exits
// through log.Fatalln if any step fails.
func authenticateWithGuestPrivileges() {
	// [START authenticate_with_guest_privileges]
	ctx := context.Background()
	// Initialize the app with a nil auth variable, limiting the server's access
	var nilMap map[string]interface{}
	conf := &firebase.Config{
		DatabaseURL:  "https://databaseName.firebaseio.com",
		AuthOverride: &nilMap,
	}

	// Fetch the service account key JSON file contents
	opt := option.WithCredentialsFile("path/to/serviceAccountKey.json")

	app, err := firebase.NewApp(ctx, conf, opt)
	if err != nil {
		log.Fatalln("Error initializing app:", err)
	}

	client, err := app.Database(ctx)
	if err != nil {
		log.Fatalln("Error initializing database client:", err)
	}

	// The app only has access to public data as defined in the Security Rules
	ref := client.NewRef("/some_resource")
	var data map[string]interface{}
	if err := ref.Get(ctx, &data); err != nil {
		log.Fatalln("Error reading from database:", err)
	}
	fmt.Println(data)
	// [END authenticate_with_guest_privileges]
}

// getReference creates a database client from app, builds a reference to
// "server/saving-data/fireblog" and prints the reference's Path. The process
// exits through log.Fatalln if the client cannot be created.
func getReference(ctx context.Context, app *firebase.App) {
	// [START get_reference]
	// Create a database client from App.
	client, err := app.Database(ctx)
	if err != nil {
		log.Fatalln("Error initializing database client:", err)
	}

	// Get a database reference to our blog.
	ref := client.NewRef("server/saving-data/fireblog")
	// [END get_reference]
	fmt.Println(ref.Path)
}

// [START user_type]

// User is a json-serializable type.
type User struct {
	// Every field is omitted from the JSON output when empty (omitempty).
	DateOfBirth string `json:"date_of_birth,omitempty"`
	FullName    string `json:"full_name,omitempty"`
	Nickname    string `json:"nickname,omitempty"`
}

// [END user_type]

// setValue writes a map of two sample users to the "users" child of ref with a
// single Set call. The process exits through log.Fatalln if the write fails.
func setValue(ctx context.Context, ref *db.Ref) {
	// [START set_value]
	usersRef := ref.Child("users")
	err := usersRef.Set(ctx, map[string]*User{
		"alanisawesome": {
			DateOfBirth: "June 23, 1912",
			FullName:    "Alan Turing",
		},
		"gracehop": {
			DateOfBirth: "December 9, 1906",
			FullName:    "Grace Hopper",
		},
	})
	if err != nil {
		log.Fatalln("Error setting value:", err)
	}
	// [END set_value]
}

// setChildValue writes the same two sample users as setValue, but with one Set
// call per child of usersRef. The process exits through log.Fatalln if either
// write fails.
func setChildValue(ctx context.Context, usersRef *db.Ref) {
	// [START set_child_value]
	if err := usersRef.Child("alanisawesome").Set(ctx, &User{
		DateOfBirth: "June 23, 1912",
		FullName:    "Alan Turing",
	}); err != nil {
		log.Fatalln("Error setting value:", err)
	}

	if err := usersRef.Child("gracehop").Set(ctx, &User{
		DateOfBirth: "December 9, 1906",
		FullName:    "Grace Hopper",
	}); err != nil {
		log.Fatalln("Error setting value:", err)
	}
	// [END set_child_value]
}

// updateChild calls Update on the "gracehop" child of usersRef with a
// "nickname" entry. The process exits through log.Fatalln if the update fails.
func updateChild(ctx context.Context, usersRef *db.Ref) {
	// [START update_child]
	hopperRef := usersRef.Child("gracehop")
	if err := hopperRef.Update(ctx, map[string]interface{}{
		"nickname": "Amazing Grace",
	}); err != nil {
		log.Fatalln("Error updating child:", err)
	}
	// [END update_child]
}

// updateChildren calls Update on usersRef once, using slash-separated keys
// ("alanisawesome/nickname", "gracehop/nickname") to address nested values.
// The process exits through log.Fatalln if the update fails.
func updateChildren(ctx context.Context, usersRef *db.Ref) {
	// [START update_children]
	if err := usersRef.Update(ctx, map[string]interface{}{
		"alanisawesome/nickname": "Alan The Machine",
		"gracehop/nickname":      "Amazing Grace",
	}); err != nil {
		log.Fatalln("Error updating children:", err)
	}
	// [END update_children]
}

// overwriteValue calls Update on usersRef with whole *User values keyed by
// user name, each carrying only a Nickname. The process exits through
// log.Fatalln if the update fails.
// NOTE(review): presumably this replaces each user's other fields, as the
// region tag "overwrite_value" suggests -- confirm against the db package docs.
func overwriteValue(ctx context.Context, usersRef *db.Ref) {
	// [START overwrite_value]
	if err := usersRef.Update(ctx, map[string]interface{}{
		"alanisawesome": &User{Nickname: "Alan The Machine"},
		"gracehop":      &User{Nickname: "Amazing Grace"},
	}); err != nil {
		log.Fatalln("Error updating children:", err)
	}
	// [END overwrite_value]
}

// [START post_type]

// Post is a
json-serializable type. type Post struct { Author string `json:"author,omitempty"` Title string `json:"title,omitempty"` } // [END post_type] func pushValue(ctx context.Context, ref *db.Ref) { // [START push_value] postsRef := ref.Child("posts") newPostRef, err := postsRef.Push(ctx, nil) if err != nil { log.Fatalln("Error pushing child node:", err) } if err := newPostRef.Set(ctx, &Post{ Author: "gracehop", Title: "Announcing COBOL, a New Programming Language", }); err != nil { log.Fatalln("Error setting value:", err) } // We can also chain the two calls together if _, err := postsRef.Push(ctx, &Post{ Author: "alanisawesome", Title: "The Turing Machine", }); err != nil { log.Fatalln("Error pushing child node:", err) } // [END push_value] } func pushAndSetValue(ctx context.Context, postsRef *db.Ref) { // [START push_and_set_value] if _, err := postsRef.Push(ctx, &Post{ Author: "gracehop", Title: "Announcing COBOL, a New Programming Language", }); err != nil { log.Fatalln("Error pushing child node:", err) } // [END push_and_set_value] } func pushKey(ctx context.Context, postsRef *db.Ref) { // [START push_key] // Generate a reference to a new location and add some data using Push() newPostRef, err := postsRef.Push(ctx, nil) if err != nil { log.Fatalln("Error pushing child node:", err) } // Get the unique key generated by Push() postID := newPostRef.Key // [END push_key] fmt.Println(postID) } func transaction(ctx context.Context, client *db.Client) { // [START transaction] fn := func(t db.TransactionNode) (interface{}, error) { var currentValue int if err := t.Unmarshal(¤tValue); err != nil { return nil, err } return currentValue + 1, nil } ref := client.NewRef("server/saving-data/fireblog/posts/-JRHTHaIs-jNPLXOQivY/upvotes") if err := ref.Transaction(ctx, fn); err != nil { log.Fatalln("Transaction failed to commit:", err) } // [END transaction] } func readValue(ctx context.Context, app *firebase.App) { // [START read_value] // Create a database client from App. 
client, err := app.Database(ctx) if err != nil { log.Fatalln("Error initializing database client:", err) } // Get a database reference to our posts ref := client.NewRef("server/saving-data/fireblog/posts") // Read the data at the posts reference (this is a blocking operation) var post Post if err := ref.Get(ctx, &post); err != nil { log.Fatalln("Error reading value:", err) } // [END read_value] fmt.Println(ref.Path) } // [START dinosaur_type] // Dinosaur is a json-serializable type. type Dinosaur struct { Height int `json:"height"` Width int `json:"width"` } // [END dinosaur_type] func orderByChild(ctx context.Context, client *db.Client) { // [START order_by_child] ref := client.NewRef("dinosaurs") results, err := ref.OrderByChild("height").GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } for _, r := range results { var d Dinosaur if err := r.Unmarshal(&d); err != nil { log.Fatalln("Error unmarshaling result:", err) } fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) } // [END order_by_child] } func orderByNestedChild(ctx context.Context, client *db.Client) { // [START order_by_nested_child] ref := client.NewRef("dinosaurs") results, err := ref.OrderByChild("dimensions/height").GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } for _, r := range results { var d Dinosaur if err := r.Unmarshal(&d); err != nil { log.Fatalln("Error unmarshaling result:", err) } fmt.Printf("%s was %d meteres tall", r.Key(), d.Height) } // [END order_by_nested_child] } func orderByKey(ctx context.Context, client *db.Client) { // [START order_by_key] ref := client.NewRef("dinosaurs") results, err := ref.OrderByKey().GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } snapshot := make([]Dinosaur, len(results)) for i, r := range results { var d Dinosaur if err := r.Unmarshal(&d); err != nil { log.Fatalln("Error unmarshaling result:", err) } snapshot[i] = d } fmt.Println(snapshot) // [END order_by_key] 
}

// orderByValue queries "scores" ordered by value, decodes each result as an int
// and prints its key and score. The process exits through log.Fatalln if the
// query or a decode fails.
func orderByValue(ctx context.Context, client *db.Client) {
	// [START order_by_value]
	ref := client.NewRef("scores")

	results, err := ref.OrderByValue().GetOrdered(ctx)
	if err != nil {
		log.Fatalln("Error querying database:", err)
	}
	for _, r := range results {
		var score int
		if err := r.Unmarshal(&score); err != nil {
			log.Fatalln("Error unmarshaling result:", err)
		}
		fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score)
	}
	// [END order_by_value]
}

// limitToLast queries "dinosaurs" ordered by the "weight" child with
// LimitToLast(2) and prints the key of each result. The process exits through
// log.Fatalln if the query fails.
func limitToLast(ctx context.Context, client *db.Client) {
	// [START limit_query_1]
	ref := client.NewRef("dinosaurs")

	results, err := ref.OrderByChild("weight").LimitToLast(2).GetOrdered(ctx)
	if err != nil {
		log.Fatalln("Error querying database:", err)
	}
	for _, r := range results {
		fmt.Println(r.Key())
	}
	// [END limit_query_1]
}

// limitToFirst queries "dinosaurs" ordered by the "height" child with
// LimitToFirst(2) and prints the key of each result. The process exits through
// log.Fatalln if the query fails.
func limitToFirst(ctx context.Context, client *db.Client) {
	// [START limit_query_2]
	ref := client.NewRef("dinosaurs")

	results, err := ref.OrderByChild("height").LimitToFirst(2).GetOrdered(ctx)
	if err != nil {
		log.Fatalln("Error querying database:", err)
	}
	for _, r := range results {
		fmt.Println(r.Key())
	}
	// [END limit_query_2]
}

// limitWithValueOrder queries "scores" ordered by value with LimitToLast(3),
// decodes each result as an int and prints its key and score. The process exits
// through log.Fatalln if the query or a decode fails.
func limitWithValueOrder(ctx context.Context, client *db.Client) {
	// [START limit_query_3]
	ref := client.NewRef("scores")

	results, err := ref.OrderByValue().LimitToLast(3).GetOrdered(ctx)
	if err != nil {
		log.Fatalln("Error querying database:", err)
	}
	for _, r := range results {
		var score int
		if err := r.Unmarshal(&score); err != nil {
			log.Fatalln("Error unmarshaling result:", err)
		}
		fmt.Printf("The %s dinosaur's score is %d\n", r.Key(), score)
	}
	// [END limit_query_3]
}

// startAt queries "dinosaurs" ordered by the "height" child with StartAt(3) and
// prints the key of each result. The process exits through log.Fatalln if the
// query fails.
func startAt(ctx context.Context, client *db.Client) {
	// [START range_query_1]
	ref := client.NewRef("dinosaurs")

	results, err := ref.OrderByChild("height").StartAt(3).GetOrdered(ctx)
	if err != nil {
		log.Fatalln("Error querying database:", err)
	}
	for _, r := range results {
		fmt.Println(r.Key())
	}
	// [END range_query_1]
}

// endAt queries "dinosaurs" ordered by key with EndAt("pterodactyl") and prints
// the key of each result. The process exits through log.Fatalln if the query
// fails.
func endAt(ctx context.Context, client *db.Client) {
	// [START range_query_2]
ref := client.NewRef("dinosaurs") results, err := ref.OrderByKey().EndAt("pterodactyl").GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } for _, r := range results { fmt.Println(r.Key()) } // [END range_query_2] } func startAndEndAt(ctx context.Context, client *db.Client) { // [START range_query_3] ref := client.NewRef("dinosaurs") results, err := ref.OrderByKey().StartAt("b").EndAt("b\uf8ff").GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } for _, r := range results { fmt.Println(r.Key()) } // [END range_query_3] } func equalTo(ctx context.Context, client *db.Client) { // [START range_query_4] ref := client.NewRef("dinosaurs") results, err := ref.OrderByChild("height").EqualTo(25).GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } for _, r := range results { fmt.Println(r.Key()) } // [END range_query_4] } func complexQuery(ctx context.Context, client *db.Client) { // [START complex_query] ref := client.NewRef("dinosaurs") var favDinoHeight int if err := ref.Child("stegosaurus").Child("height").Get(ctx, &favDinoHeight); err != nil { log.Fatalln("Error querying database:", err) } query := ref.OrderByChild("height").EndAt(favDinoHeight).LimitToLast(2) results, err := query.GetOrdered(ctx) if err != nil { log.Fatalln("Error querying database:", err) } if len(results) == 2 { // Data is ordered by increasing height, so we want the first entry. // Second entry is stegosarus. fmt.Printf("The dinosaur just shorter than the stegosaurus is %s\n", results[0].Key()) } else { fmt.Println("The stegosaurus is the shortest dino") } // [END complex_query] } golang-google-firebase-go-4.18.0/snippets/init.go000066400000000000000000000075031505612111400216460ustar00rootroot00000000000000// Copyright 2017 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package snippets // [START admin_import_golang] import ( "context" "log" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/auth" "google.golang.org/api/option" ) // [END admin_import_golang] // ================================================================== // https://firebase.google.com/docs/admin/setup // ================================================================== func initializeAppWithServiceAccount() *firebase.App { // [START initialize_app_service_account_golang] opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") app, err := firebase.NewApp(context.Background(), nil, opt) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // [END initialize_app_service_account_golang] return app } func initializeAppWithRefreshToken() *firebase.App { // [START initialize_app_refresh_token_golang] opt := option.WithCredentialsFile("path/to/refreshToken.json") config := &firebase.Config{ProjectID: "my-project-id"} app, err := firebase.NewApp(context.Background(), config, opt) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // [END initialize_app_refresh_token_golang] return app } func initializeAppDefault() *firebase.App { // [START initialize_app_default_golang] app, err := firebase.NewApp(context.Background(), nil) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // [END initialize_app_default_golang] return app } func initializeServiceAccountID() *firebase.App { // [START initialize_sdk_with_service_account_id] conf := &firebase.Config{ ServiceAccountID: 
"my-client-id@my-project-id.iam.gserviceaccount.com", } app, err := firebase.NewApp(context.Background(), conf) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // [END initialize_sdk_with_service_account_id] return app } func accessServicesSingleApp() (*auth.Client, error) { // [START access_services_single_app_golang] // Initialize default app app, err := firebase.NewApp(context.Background(), nil) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // Access auth service from the default app client, err := app.Auth(context.Background()) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // [END access_services_single_app_golang] return client, err } func accessServicesMultipleApp() (*auth.Client, error) { // [START access_services_multiple_app_golang] // Initialize the default app defaultApp, err := firebase.NewApp(context.Background(), nil) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // Initialize another app with a different config opt := option.WithCredentialsFile("service-account-other.json") otherApp, err := firebase.NewApp(context.Background(), nil, opt) if err != nil { log.Fatalf("error initializing app: %v\n", err) } // Access Auth service from default app defaultClient, err := defaultApp.Auth(context.Background()) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // Access auth service from other app otherClient, err := otherApp.Auth(context.Background()) if err != nil { log.Fatalf("error getting Auth client: %v\n", err) } // [END access_services_multiple_app_golang] // Avoid unused _ = defaultClient return otherClient, nil } golang-google-firebase-go-4.18.0/snippets/messaging.go000066400000000000000000000265231505612111400226630ustar00rootroot00000000000000// Copyright 2018 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package snippets import ( "context" "fmt" "log" "time" firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/messaging" ) func sendToToken(app *firebase.App) { // [START send_to_token_golang] // Obtain a messaging.Client from the App. ctx := context.Background() client, err := app.Messaging(ctx) if err != nil { log.Fatalf("error getting Messaging client: %v\n", err) } // This registration token comes from the client FCM SDKs. registrationToken := "YOUR_REGISTRATION_TOKEN" // See documentation on defining a message payload. message := &messaging.Message{ Data: map[string]string{ "score": "850", "time": "2:45", }, Token: registrationToken, } // Send a message to the device corresponding to the provided // registration token. response, err := client.Send(ctx, message) if err != nil { log.Fatalln(err) } // Response is a message ID string. fmt.Println("Successfully sent message:", response) // [END send_to_token_golang] } func sendToTopic(ctx context.Context, client *messaging.Client) { // [START send_to_topic_golang] // The topic name can be optionally prefixed with "/topics/". topic := "highScores" // See documentation on defining a message payload. message := &messaging.Message{ Data: map[string]string{ "score": "850", "time": "2:45", }, Topic: topic, } // Send a message to the devices subscribed to the provided topic. response, err := client.Send(ctx, message) if err != nil { log.Fatalln(err) } // Response is a message ID string. 
fmt.Println("Successfully sent message:", response) // [END send_to_topic_golang] } func sendToCondition(ctx context.Context, client *messaging.Client) { // [START send_to_condition_golang] // Define a condition which will send to devices which are subscribed // to either the Google stock or the tech industry topics. condition := "'stock-GOOG' in topics || 'industry-tech' in topics" // See documentation on defining a message payload. message := &messaging.Message{ Data: map[string]string{ "score": "850", "time": "2:45", }, Condition: condition, } // Send a message to devices subscribed to the combination of topics // specified by the provided condition. response, err := client.Send(ctx, message) if err != nil { log.Fatalln(err) } // Response is a message ID string. fmt.Println("Successfully sent message:", response) // [END send_to_condition_golang] } func sendAll(ctx context.Context, client *messaging.Client) { // This registration token comes from the client FCM SDKs. registrationToken := "YOUR_REGISTRATION_TOKEN" // [START send_all] // Create a list containing up to 500 messages. messages := []*messaging.Message{ { Notification: &messaging.Notification{ Title: "Price drop", Body: "5% off all electronics", }, Token: registrationToken, }, { Notification: &messaging.Notification{ Title: "Price drop", Body: "2% off all books", }, Topic: "readers-club", }, } br, err := client.SendAll(context.Background(), messages) if err != nil { log.Fatalln(err) } // See the BatchResponse reference documentation // for the contents of response. fmt.Printf("%d messages were sent successfully\n", br.SuccessCount) // [END send_all] } func sendEach(ctx context.Context, client *messaging.Client) { // This registration token comes from the client FCM SDKs. registrationToken := "YOUR_REGISTRATION_TOKEN" // [START send_each] // Create a list containing up to 500 messages. 
messages := []*messaging.Message{ { Notification: &messaging.Notification{ Title: "Price drop", Body: "5% off all electronics", }, Token: registrationToken, }, { Notification: &messaging.Notification{ Title: "Price drop", Body: "2% off all books", }, Topic: "readers-club", }, } br, err := client.SendEach(context.Background(), messages) if err != nil { log.Fatalln(err) } // See the BatchResponse reference documentation // for the contents of response. fmt.Printf("%d messages were sent successfully\n", br.SuccessCount) // [END send_each] } func sendMulticast(ctx context.Context, client *messaging.Client) { // [START send_multicast] // Create a list containing up to 500 registration tokens. // This registration tokens come from the client FCM SDKs. registrationTokens := []string{ "YOUR_REGISTRATION_TOKEN_1", // ... "YOUR_REGISTRATION_TOKEN_n", } message := &messaging.MulticastMessage{ Data: map[string]string{ "score": "850", "time": "2:45", }, Tokens: registrationTokens, } br, err := client.SendMulticast(context.Background(), message) if err != nil { log.Fatalln(err) } // See the BatchResponse reference documentation // for the contents of response. fmt.Printf("%d messages were sent successfully\n", br.SuccessCount) // [END send_multicast] } func sendMulticastAndHandleErrors(ctx context.Context, client *messaging.Client) { // [START send_multicast_error] // Create a list containing up to 500 registration tokens. // This registration tokens come from the client FCM SDKs. registrationTokens := []string{ "YOUR_REGISTRATION_TOKEN_1", // ... 
"YOUR_REGISTRATION_TOKEN_n", } message := &messaging.MulticastMessage{ Data: map[string]string{ "score": "850", "time": "2:45", }, Tokens: registrationTokens, } br, err := client.SendMulticast(context.Background(), message) if err != nil { log.Fatalln(err) } if br.FailureCount > 0 { var failedTokens []string for idx, resp := range br.Responses { if !resp.Success { // The order of responses corresponds to the order of the registration tokens. failedTokens = append(failedTokens, registrationTokens[idx]) } } fmt.Printf("List of tokens that caused failures: %v\n", failedTokens) } // [END send_multicast_error] } func sendEachForMulticastAndHandleErrors(ctx context.Context, client *messaging.Client) { // [START send_each_for_multicast_error] // Create a list containing up to 500 registration tokens. // This registration tokens come from the client FCM SDKs. registrationTokens := []string{ "YOUR_REGISTRATION_TOKEN_1", // ... "YOUR_REGISTRATION_TOKEN_n", } message := &messaging.MulticastMessage{ Data: map[string]string{ "score": "850", "time": "2:45", }, Tokens: registrationTokens, } br, err := client.SendEachForMulticast(context.Background(), message) if err != nil { log.Fatalln(err) } if br.FailureCount > 0 { var failedTokens []string for idx, resp := range br.Responses { if !resp.Success { // The order of responses corresponds to the order of the registration tokens. failedTokens = append(failedTokens, registrationTokens[idx]) } } fmt.Printf("List of tokens that caused failures: %v\n", failedTokens) } // [END send_each_for_multicast_error] } func sendDryRun(ctx context.Context, client *messaging.Client) { message := &messaging.Message{ Data: map[string]string{ "score": "850", "time": "2:45", }, Token: "token", } // [START send_dry_run_golang] // Send a message in the dry run mode. response, err := client.SendDryRun(ctx, message) if err != nil { log.Fatalln(err) } // Response is a message ID string. 
fmt.Println("Dry run successful:", response) // [END send_dry_run_golang] } func androidMessage() *messaging.Message { // [START android_message_golang] oneHour := time.Duration(1) * time.Hour message := &messaging.Message{ Android: &messaging.AndroidConfig{ TTL: &oneHour, Priority: "normal", Notification: &messaging.AndroidNotification{ Title: "$GOOG up 1.43% on the day", Body: "$GOOG gained 11.80 points to close at 835.67, up 1.43% on the day.", Icon: "stock_ticker_update", Color: "#f45342", }, }, Topic: "industry-tech", } // [END android_message_golang] return message } func apnsMessage() *messaging.Message { // [START apns_message_golang] badge := 42 message := &messaging.Message{ APNS: &messaging.APNSConfig{ Headers: map[string]string{ "apns-priority": "10", }, Payload: &messaging.APNSPayload{ Aps: &messaging.Aps{ Alert: &messaging.ApsAlert{ Title: "$GOOG up 1.43% on the day", Body: "$GOOG gained 11.80 points to close at 835.67, up 1.43% on the day.", }, Badge: &badge, }, }, }, Topic: "industry-tech", } // [END apns_message_golang] return message } func webpushMessage() *messaging.Message { // [START webpush_message_golang] message := &messaging.Message{ Webpush: &messaging.WebpushConfig{ Notification: &messaging.WebpushNotification{ Title: "$GOOG up 1.43% on the day", Body: "$GOOG gained 11.80 points to close at 835.67, up 1.43% on the day.", Icon: "https://my-server/icon.png", }, }, Topic: "industry-tech", } // [END webpush_message_golang] return message } func allPlatformsMessage() *messaging.Message { // [START multi_platforms_message_golang] oneHour := time.Duration(1) * time.Hour badge := 42 message := &messaging.Message{ Notification: &messaging.Notification{ Title: "$GOOG up 1.43% on the day", Body: "$GOOG gained 11.80 points to close at 835.67, up 1.43% on the day.", }, Android: &messaging.AndroidConfig{ TTL: &oneHour, Notification: &messaging.AndroidNotification{ Icon: "stock_ticker_update", Color: "#f45342", }, }, APNS: &messaging.APNSConfig{ 
Payload: &messaging.APNSPayload{ Aps: &messaging.Aps{ Badge: &badge, }, }, }, Topic: "industry-tech", } // [END multi_platforms_message_golang] return message } func subscribeToTopic(ctx context.Context, client *messaging.Client) { topic := "highScores" // [START subscribe_golang] // These registration tokens come from the client FCM SDKs. registrationTokens := []string{ "YOUR_REGISTRATION_TOKEN_1", // ... "YOUR_REGISTRATION_TOKEN_n", } // Subscribe the devices corresponding to the registration tokens to the // topic. response, err := client.SubscribeToTopic(ctx, registrationTokens, topic) if err != nil { log.Fatalln(err) } // See the TopicManagementResponse reference documentation // for the contents of response. fmt.Println(response.SuccessCount, "tokens were subscribed successfully") // [END subscribe_golang] } func unsubscribeFromTopic(ctx context.Context, client *messaging.Client) { topic := "highScores" // [START unsubscribe_golang] // These registration tokens come from the client FCM SDKs. registrationTokens := []string{ "YOUR_REGISTRATION_TOKEN_1", // ... "YOUR_REGISTRATION_TOKEN_n", } // Unsubscribe the devices corresponding to the registration tokens from // the topic. response, err := client.UnsubscribeFromTopic(ctx, registrationTokens, topic) if err != nil { log.Fatalln(err) } // See the TopicManagementResponse reference documentation // for the contents of response. fmt.Println(response.SuccessCount, "tokens were unsubscribed successfully") // [END unsubscribe_golang] } golang-google-firebase-go-4.18.0/snippets/storage.go000066400000000000000000000037271505612111400223530ustar00rootroot00000000000000// Copyright 2017 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package snippets import ( "context" "log" firebase "firebase.google.com/go/v4" "google.golang.org/api/option" ) // ================================================================== // https://firebase.google.com/docs/storage/admin/start // ================================================================== func cloudStorage() { // [START cloud_storage_golang] config := &firebase.Config{ StorageBucket: ".appspot.com", } opt := option.WithCredentialsFile("path/to/serviceAccountKey.json") app, err := firebase.NewApp(context.Background(), config, opt) if err != nil { log.Fatalln(err) } client, err := app.Storage(context.Background()) if err != nil { log.Fatalln(err) } bucket, err := client.DefaultBucket() if err != nil { log.Fatalln(err) } // 'bucket' is an object defined in the cloud.google.com/go/storage package. // See https://godoc.org/cloud.google.com/go/storage#BucketHandle // for more details. 
// [END cloud_storage_golang] log.Printf("Created bucket handle: %v\n", bucket) } func cloudStorageCustomBucket(app *firebase.App) { client, err := app.Storage(context.Background()) if err != nil { log.Fatalln(err) } // [START cloud_storage_custom_bucket_golang] bucket, err := client.Bucket("my-custom-bucket") // [END cloud_storage_custom_bucket_golang] if err != nil { log.Fatalln(err) } log.Printf("Created bucket handle: %v\n", bucket) } golang-google-firebase-go-4.18.0/storage/000077500000000000000000000000001505612111400201465ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/storage/storage.go000066400000000000000000000040631505612111400221440ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package storage provides functions for accessing Google Cloud Storge buckets. package storage import ( "context" "errors" "os" "cloud.google.com/go/storage" "firebase.google.com/go/v4/internal" ) // Client is the interface for the Firebase Storage service. type Client struct { client *storage.Client bucket string } // NewClient creates a new instance of the Firebase Storage Client. // // This function can only be invoked from within the SDK. Client applications should access the // the Storage service through firebase.App. 
func NewClient(ctx context.Context, c *internal.StorageConfig) (*Client, error) { if os.Getenv("STORAGE_EMULATOR_HOST") == "" && os.Getenv("FIREBASE_STORAGE_EMULATOR_HOST") != "" { os.Setenv("STORAGE_EMULATOR_HOST", os.Getenv("FIREBASE_STORAGE_EMULATOR_HOST")) } client, err := storage.NewClient(ctx, c.Opts...) if err != nil { return nil, err } return &Client{client: client, bucket: c.Bucket}, nil } // DefaultBucket returns a handle to the default Cloud Storage bucket. // // To use this method, the default bucket name must be specified via firebase.Config when // initializing the App. func (c *Client) DefaultBucket() (*storage.BucketHandle, error) { return c.Bucket(c.bucket) } // Bucket returns a handle to the specified Cloud Storage bucket. func (c *Client) Bucket(name string) (*storage.BucketHandle, error) { if name == "" { return nil, errors.New("bucket name not specified") } return c.client.Bucket(name), nil } golang-google-firebase-go-4.18.0/storage/storage_test.go000066400000000000000000000055431505612111400232070ustar00rootroot00000000000000// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package storage import ( "context" "os" "testing" "firebase.google.com/go/v4/internal" "google.golang.org/api/option" ) var opts = []option.ClientOption{ option.WithCredentialsFile("../testdata/service_account.json"), } func TestNewClientError(t *testing.T) { invalid := []option.ClientOption{ option.WithCredentialsFile("../testdata/non_existing.json"), } client, err := NewClient(context.Background(), &internal.StorageConfig{ Opts: invalid, }) if client != nil || err == nil { t.Errorf("NewClient() = (%v, %v); want (nil, error)", client, err) } } func TestNewClientEmulatorHostEnvVar(t *testing.T) { emulatorHost := "localhost:9099" os.Setenv("FIREBASE_STORAGE_EMULATOR_HOST", emulatorHost) defer os.Unsetenv("FIREBASE_STORAGE_EMULATOR_HOST") os.Unsetenv("STORAGE_EMULATOR_HOST") defer os.Unsetenv("STORAGE_EMULATOR_HOST") _, err := NewClient(context.Background(), &internal.StorageConfig{ Opts: opts, }) if err != nil { t.Fatal(err) } if host := os.Getenv("STORAGE_EMULATOR_HOST"); host != emulatorHost { t.Errorf("emulator host: %q; want: %q", host, emulatorHost) } } func TestNoBucketName(t *testing.T) { client, err := NewClient(context.Background(), &internal.StorageConfig{ Opts: opts, }) if err != nil { t.Fatal(err) } if _, err := client.DefaultBucket(); err == nil { t.Errorf("DefaultBucket() = nil; want error") } } func TestEmptyBucketName(t *testing.T) { client, err := NewClient(context.Background(), &internal.StorageConfig{ Opts: opts, }) if err != nil { t.Fatal(err) } if _, err := client.Bucket(""); err == nil { t.Errorf("Bucket('') = nil; want error") } } func TestDefaultBucket(t *testing.T) { client, err := NewClient(context.Background(), &internal.StorageConfig{ Bucket: "bucket.name", Opts: opts, }) if err != nil { t.Fatal(err) } bucket, err := client.DefaultBucket() if bucket == nil || err != nil { t.Errorf("DefaultBucket() = (%v, %v); want: (bucket, nil)", bucket, err) } } func TestBucket(t *testing.T) { client, err := NewClient(context.Background(), 
&internal.StorageConfig{ Opts: opts, }) if err != nil { t.Fatal(err) } bucket, err := client.Bucket("bucket.name") if bucket == nil || err != nil { t.Errorf("Bucket() = (%v, %v); want: (bucket, nil)", bucket, err) } } golang-google-firebase-go-4.18.0/testdata/000077500000000000000000000000001505612111400203135ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/testdata/appcheck_pk.pem000066400000000000000000000032121505612111400232640ustar00rootroot00000000000000-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEArFYQyEdjj43mnpXwj+3WgAE01TSYe1+XFE9mxUDShysFwtVZ OHFSMm6kl+B3Y/O8NcPt5osntLlH6KHvygExAE0tDmFYq8aKt7LQQF8rTv0rI6MP 92ezyCEp4MPmAPFD/tY160XGrkqApuY2/+L8eEXdkRyH2H7lCYypFC0u3DIY25Vl q+ZDkxB2kGykGgb1zVazCDDViqV1p9hSltmm4el9AyF08FsMCpk/NvwKOY4pJ/sm 99CDKxMhQBaT9lrIQt0B1VqTpEwlOoiFiyXASRXp9ZTeL4mrLPqSeozwPvspD81w bgecd62F640scKBr3ko73L8M8UWcwgd+moKCJwIDAQABAoIBAEDPJQSMhE6KKL5e 2NbntJDy4zGC1A0hh6llqtpnZETc0w/QN/tX8ndw0IklKwD1ukPl6OOYVVhLjVVZ ANpQ1GKuo1ETHsuKoMQwhMyQfbL41m5SdkCuSRfsENmsEiUslkuRtzlBRlRpRDR/ wxM8A4IflBFsT1IFdpC+yx8BVuwLc35iVnaGQpo/jhSDibt07j+FdOKEWkMGj+rL sHC6cpB2NMTBl9CIDLW/eq1amBOAGtsSKqoGJvaQY/mZf7SPkRjYIfIl2PWSaduT fmMrsYYFtHUKVOMYAD7P5RWNkS8oERucnXT3ouAECvip3Ew2JqlQc0FP7FS5CxH3 WdfvLuECgYEA8Q7rJrDOdO867s7P/lXMklbAGnuNnAZJdAEXUMIaPJi7al97F119 4DKBuF7c/dDf8CdiOvMzP8r/F8+FFx2D61xxkQNeuxo5Xjlt23OzW5EI2S6ABesZ /3sQWqvKCGuqN7WENYF3EiKyByQ22MYXk8CE7KZuO57Aj88t6TsaNhkCgYEAtwSs hbqKSCneC1bQ3wfSAF2kPYRrQEEa2VCLlX1Mz7zHufxksUWAnAbU8O3hIGnXjz6T qzivyJJhFSgNGeYpwV67GfXnibpr3OZ/yx2YXIQfp0daivj++kvEU7aNfM9rHZA9 S3Gh7hKELdB9b0DkrX5GpLiZWA6NnJdrIRYbAj8CgYBCZSyJvJsxBA+EZTxOvk0Z ZYGGCc/oUKb8p6xHVx8o35yHYQMjXWHlVaP7J03RLy3vFLnuqLvN71ixszviMQP7 2LuDCJ2YBVIVzNWgY07cgqcgQrmKZ8YCY2AOyVBdX2JD8+AVaLJmMV49r1DYBj/K N3WlRPYJv+Ej+xmXKus+SQKBgHh/Zkthxxu+HQigL0M4teYxwSoTnj2e39uGsXBK ICGCLIniiDVDCmswAFFkfV3G8frI+5a26t2Gqs6wIPgVVxaOlWeBROGkUNIPHMKR iLgY8XJEg3OOfuoyql9niP5M3jyHtCOQ/Elv/YDgjUWLl0Q3KLHZLHUSl+AqvYj6 
MewnAoGBANgYzPZgP+wreI55BFR470blKh1mFz+YGa+53DCd7JdMH2pdp4hoh303 XxpOSVlAuyv9SgTsZ7WjGO5UdhaBzVPKgN0OO6JQmQ5ZrOR8ZJ7VB73FiVHCEerj 1m2zyFv6OT7vqdg+V1/SzxMEmXXFQv1g69k6nWGazne3IJlzrSpj -----END RSA PRIVATE KEY-----golang-google-firebase-go-4.18.0/testdata/dinosaurs.json000066400000000000000000000027651505612111400232270ustar00rootroot00000000000000{ "dinosaurs": { "bruhathkayosaurus": { "appeared": -70000000, "height": 25, "length": 44, "order": "saurischia", "vanished": -70000000, "weight": 135000, "ratings": { "pos": 1 } }, "lambeosaurus": { "appeared": -76000000, "height": 2.1, "length": 12.5, "order": "ornithischia", "vanished": -75000000, "weight": 5000, "ratings": { "pos": 2 } }, "linhenykus": { "appeared": -85000000, "height": 0.6, "length": 1, "order": "theropoda", "vanished": -75000000, "weight": 3, "ratings": { "pos": 3 } }, "pterodactyl": { "appeared": -150000000, "height": 0.6, "length": 0.8, "order": "pterosauria", "vanished": -148500000, "weight": 2, "ratings": { "pos": 4 } }, "stegosaurus": { "appeared": -155000000, "height": 4, "length": 9, "order": "ornithischia", "vanished": -150000000, "weight": 2500, "ratings": { "pos": 5 } }, "triceratops": { "appeared": -68000000, "height": 3, "length": 8, "order": "ornithischia", "vanished": -66000000, "weight": 11000, "ratings": { "pos": 6 } } }, "scores": { "bruhathkayosaurus": 55, "lambeosaurus": 21, "linhenykus": 80, "pterodactyl": 93, "stegosaurus": 5, "triceratops": 22 } } golang-google-firebase-go-4.18.0/testdata/dinosaurs_index.json000066400000000000000000000013761505612111400244130ustar00rootroot00000000000000{ "rules": { "_adminsdk": { "go": { "dinodb": { "dinosaurs": { ".indexOn": ["height", "ratings/pos"] }, "scores": { ".indexOn": ".value" } }, "protected": { "$uid": { ".read": "auth != null", ".write": "$uid === auth.uid" } }, "admin": { ".read": "false", ".write": "false" }, "public": { ".read": "true" } } } } 
}golang-google-firebase-go-4.18.0/testdata/firebase_config.json000066400000000000000000000002141505612111400243100ustar00rootroot00000000000000{ "databaseURL": "https://auto-init.database.url", "projectId": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" } golang-google-firebase-go-4.18.0/testdata/firebase_config_empty.json000066400000000000000000000000001505612111400255170ustar00rootroot00000000000000golang-google-firebase-go-4.18.0/testdata/firebase_config_invalid.json000066400000000000000000000000061505612111400260150ustar00rootroot00000000000000baaad golang-google-firebase-go-4.18.0/testdata/firebase_config_invalid_key.json000066400000000000000000000001411505612111400266650ustar00rootroot00000000000000{ "project1d_bad_key": "auto-init-project-id", "storageBucket": "auto-init.storage.bucket" } golang-google-firebase-go-4.18.0/testdata/firebase_config_partial.json000066400000000000000000000000521505612111400260240ustar00rootroot00000000000000{ "projectId": "auto-init-project-id" } golang-google-firebase-go-4.18.0/testdata/get_disabled_user.json000066400000000000000000000012171505612111400246530ustar00rootroot00000000000000{ "kind": "identitytoolkit#GetAccountInfoResponse", "users": [ { "localId": "testuser", "email": "testuser@example.com", "phoneNumber": "+1234567890", "emailVerified": true, "displayName": "Test User", "photoUrl": "http://www.example.com/testuser/photo.png", "passwordHash": "passwordhash", "salt": "salt===", "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": true, "createdAt": "1234567890000", "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}", "tenantId": "testTenant" } ] } golang-google-firebase-go-4.18.0/testdata/get_user.json000066400000000000000000000027631505612111400230330ustar00rootroot00000000000000{ "kind": "identitytoolkit#GetAccountInfoResponse", "users": [ { "localId": "testuser", "email": "testuser@example.com", "phoneNumber": "+1234567890", 
"emailVerified": true, "displayName": "Test User", "providerUserInfo": [ { "providerId": "password", "displayName": "Test User", "photoUrl": "http://www.example.com/testuser/photo.png", "federatedId": "testuser@example.com", "email": "testuser@example.com", "rawId": "testuid" }, { "providerId": "phone", "phoneNumber": "+1234567890", "rawId": "testuid" } ], "photoUrl": "http://www.example.com/testuser/photo.png", "passwordHash": "passwordhash", "salt": "salt===", "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, "createdAt": "1234567890000", "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}", "tenantId": "testTenant", "mfaInfo": [ { "phoneInfo": "+1234567890", "mfaEnrollmentId": "enrolledPhoneFactor", "displayName": "My MFA Phone", "enrolledAt": "2021-03-03T13:06:20.542896Z" }, { "totpInfo": {}, "mfaEnrollmentId": "enrolledTOTPFactor", "displayName": "My MFA TOTP", "enrolledAt": "2021-03-03T13:06:20.542896Z" } ] } ] } golang-google-firebase-go-4.18.0/testdata/invalid_service_account.json000066400000000000000000000041721505612111400260740ustar00rootroot00000000000000{ "type": "service_account", "project_id": "mock-project-id", "private_key_id": "mock-key-id-1", "private_key": "-----BEGIN RSA PRIVATE 
KEY-----\nMIIEpAIBAAKCAQEAwJENcRev+eXZKvhhWLiV3Lz2MvO+naQRHo59g3vaNQnbgyduN/L4krlr\nJ5c6FiikXdtJNb/QrsAHSyJWCu8j3T9CruiwbidGAk2W0RuViTVspjHUTsIHExx9euWM0Uom\nGvYkoqXahdhPL/zViVSJt+Rt8bHLsMvpb8RquTIb9iKY3SMV2tCofNmyCSgVbghq/y7lKORt\nV/IRguWs6R22fbkb0r2MCYoNAbZ9dqnbRIFNZBC7itYtUoTEresRWcyFMh0zfAIJycWOJlVL\nDLqkY2SmIx8u7fuysCg1wcoSZoStuDq02nZEMw1dx8HGzE0hynpHlloRLByuIuOAfMCCYwID\nAQABAoIBADFtihu7TspAO0wSUTpqttzgC/nsIsNn95T2UjVLtyjiDNxPZLUrwq42tdCFur0x\nVW9Z+CK5x6DzXWvltlw8IeKKeF1ZEOBVaFzy+YFXKTz835SROcO1fgdjyrme7lRSShGlmKW/\nGKY+baUNquoDLw5qreXaE0SgMp0jt5ktyYuVxvhLDeV4omw2u6waoGkifsGm8lYivg5l3VR7\nw2IVOvYZTt4BuSYVwOM+qjwaS1vtL7gv0SUjrj85Ja6zERRdFiITDhZw6nsvacr9/+/aut9E\naL/koSSb62g5fntQMEwoT4hRnjPnAedmorM9Rhddh2TB3ZKTBbMN1tUk3fJxOuECgYEA+z6l\neSaAcZ3qvwpntcXSpwwJ0SSmzLTH2RJNf+Ld3eBHiSvLTG53dWB7lJtF4R1KcIwf+KGcOFJv\nsnepzcZBylRvT8RrAAkV0s9OiVm1lXZyaepbLg4GGFJBPi8A6VIAj7zYknToRApdW0s1x/XX\nChewfJDckqsevTMovdbg8YkCgYEAxDYX+3mfvv/opo6HNNY3SfVunM+4vVJL+n8gWZ2w9kz3\nQ9Ub9YbRmI7iQaiVkO5xNuoG1n9bM+3Mnm84aQ1YeNT01YqeyQsipP5Wi+um0PzYTaBw9RO+\n8Gh6992OwlJiRtFk5WjalNWOxY4MU0ImnJwIfKQlUODvLmcixm68NYsCgYEAuAqI3jkk55Vd\nKvotREsX5wP7gPePM+7NYiZ1HNQL4Ab1f/bTojZdTV8Sx6YCR0fUiqMqnE+OBvfkGGBtw22S\nLesx6sWf99Ov58+x4Q0U5dpxL0Lb7d2Z+2Dtp+Z4jXFjNeeI4ae/qG/LOR/b0pE0J5F415ap\n7Mpq5v89vepUtrkCgYAjMXytu4v+q1Ikhc4UmRPDrUUQ1WVSd+9u19yKlnFGTFnRjej86hiw\nH3jPxBhHra0a53EgiilmsBGSnWpl1WH4EmJz5vBCKUAmjgQiBrueIqv9iHiaTNdjsanUyaWw\njyxXfXl2eI80QPXh02+8g1H/pzESgjK7Rg1AqnkfVH9nrwKBgQDJVxKBPTw9pigYMVt9iHrR\niCl9zQVjRMbWiPOc0J56+/5FZYm/AOGl9rfhQ9vGxXZYZiOP5FsNkwt05Y1UoAAH4B4VQwbL\nqod71qOcI0ywgZiIR87CYw40gzRfjWnN+YEEW1qfyoNLilEwJB8iB/T+ZePHGmJ4MmQ/cTn9\nxpdLXA==\n-----END RSA PRIVATE KEY-----", "client_id": "1234567890", "auth_uri": "https://accounts.google.com/o/oauth2/auth", "token_uri": "https://accounts.google.com/o/oauth2/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_x509_cert_url": 
"https://www.googleapis.com/robot/v1/metadata/x509/mock-project-id.iam.gserviceaccount.com" } golang-google-firebase-go-4.18.0/testdata/list_users.json000066400000000000000000000116431505612111400234070ustar00rootroot00000000000000{ "kind": "identitytoolkit#DownloadAccountResponse", "users": [ { "localId": "testuser", "email": "testuser@example.com", "phoneNumber": "+1234567890", "emailVerified": true, "displayName": "Test User", "providerUserInfo": [ { "providerId": "password", "displayName": "Test User", "photoUrl": "http://www.example.com/testuser/photo.png", "federatedId": "testuser@example.com", "email": "testuser@example.com", "rawId": "testuid" }, { "providerId": "phone", "phoneNumber": "+1234567890", "rawId": "testuid" } ], "photoUrl": "http://www.example.com/testuser/photo.png", "passwordHash": "passwordhash1", "salt": "salt1", "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, "createdAt": "1234567890000", "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}", "tenantId": "testTenant", "mfaInfo": [ { "phoneInfo": "+1234567890", "mfaEnrollmentId": "enrolledPhoneFactor", "displayName": "My MFA Phone", "enrolledAt": "2021-03-03T13:06:20.542896Z" }, { "totpInfo": {}, "mfaEnrollmentId": "enrolledTOTPFactor", "displayName": "My MFA TOTP", "enrolledAt": "2021-03-03T13:06:20.542896Z" } ] }, { "localId": "testuser", "email": "testuser@example.com", "phoneNumber": "+1234567890", "emailVerified": true, "displayName": "Test User", "providerUserInfo": [ { "providerId": "password", "displayName": "Test User", "photoUrl": "http://www.example.com/testuser/photo.png", "federatedId": "testuser@example.com", "email": "testuser@example.com", "rawId": "testuid" }, { "providerId": "phone", "phoneNumber": "+1234567890", "rawId": "testuid" } ], "photoUrl": "http://www.example.com/testuser/photo.png", "passwordHash": "passwordhash2", "salt": "salt2", "passwordUpdatedAt": 1.494364393E+12, "validSince": 
"1494364393", "disabled": false, "createdAt": "1234567890000", "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}", "tenantId": "testTenant", "mfaInfo": [ { "phoneInfo": "+1234567890", "mfaEnrollmentId": "enrolledPhoneFactor", "displayName": "My MFA Phone", "enrolledAt": "2021-03-03T13:06:20.542896Z" }, { "totpInfo": {}, "mfaEnrollmentId": "enrolledTOTPFactor", "displayName": "My MFA TOTP", "enrolledAt": "2021-03-03T13:06:20.542896Z" } ] }, { "localId": "testusernomfa", "email": "testusernomfa@example.com", "phoneNumber": "+1234567890", "emailVerified": true, "displayName": "Test User Without MFA", "providerUserInfo": [ { "providerId": "password", "displayName": "Test User Without MFA", "photoUrl": "http://www.example.com/testusernomfa/photo.png", "federatedId": "testusernomfa@example.com", "email": "testusernomfa@example.com", "rawId": "testuid" }, { "providerId": "phone", "phoneNumber": "+1234567890", "rawId": "testuid" } ], "photoUrl": "http://www.example.com/testusernomfa/photo.png", "passwordHash": "passwordhash3", "salt": "salt3", "passwordUpdatedAt": 1.494364393E+12, "validSince": "1494364393", "disabled": false, "createdAt": "1234567890000", "lastLoginAt": "1233211232000", "customAttributes": "{\"admin\": true, \"package\": \"gold\"}", "tenantId": "testTenant" } ], "nextPageToken": "" } golang-google-firebase-go-4.18.0/testdata/mock.jwks.json000066400000000000000000000010461505612111400231150ustar00rootroot00000000000000{ "keys": [ { "kty": "RSA", "e": "AQAB", "use": "sig", "kid": "FGQdnRlzAmKyKr6-Hg_kMQrBkj_H6i6ADnBQz4OI6BU", "alg": "RS256", "n": "rFYQyEdjj43mnpXwj-3WgAE01TSYe1-XFE9mxUDShysFwtVZOHFSMm6kl-B3Y_O8NcPt5osntLlH6KHvygExAE0tDmFYq8aKt7LQQF8rTv0rI6MP92ezyCEp4MPmAPFD_tY160XGrkqApuY2_-L8eEXdkRyH2H7lCYypFC0u3DIY25Vlq-ZDkxB2kGykGgb1zVazCDDViqV1p9hSltmm4el9AyF08FsMCpk_NvwKOY4pJ_sm99CDKxMhQBaT9lrIQt0B1VqTpEwlOoiFiyXASRXp9ZTeL4mrLPqSeozwPvspD81wbgecd62F640scKBr3ko73L8M8UWcwgd-moKCJw" } ] 
}golang-google-firebase-go-4.18.0/testdata/plain_text.txt000066400000000000000000000000051505612111400232160ustar00rootroot00000000000000Test golang-google-firebase-go-4.18.0/testdata/public_certs.json000066400000000000000000000074211505612111400236700ustar00rootroot00000000000000{ "mock-key-id-1": "-----BEGIN CERTIFICATE-----\nMIIEFTCCAv2gAwIBAgIJALLYfi2oN8cPMA0GCSqGSIb3DQEBCwUAMIGgMQswCQYD\nVQQGEwJVUzELMAkGA1UECAwCQ0ExFjAUBgNVBAcMDU1vdW50YWluIFZpZXcxDzAN\nBgNVBAoMBkdvb2dsZTERMA8GA1UECwwIRmlyZWJhc2UxHDAaBgNVBAMME2ZpcmVi\nYXNlLmdvb2dsZS5jb20xKjAoBgkqhkiG9w0BCQEWG3N1cHBvcnRAZmlyZWJhc2Uu\nZ29vZ2xlLmNvbTAeFw0xNzAzMjIwMDM4MzRaFw0yNzAzMjAwMDM4MzRaMIGgMQsw\nCQYDVQQGEwJVUzELMAkGA1UECAwCQ0ExFjAUBgNVBAcMDU1vdW50YWluIFZpZXcx\nDzANBgNVBAoMBkdvb2dsZTERMA8GA1UECwwIRmlyZWJhc2UxHDAaBgNVBAMME2Zp\ncmViYXNlLmdvb2dsZS5jb20xKjAoBgkqhkiG9w0BCQEWG3N1cHBvcnRAZmlyZWJh\nc2UuZ29vZ2xlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMCR\nDXEXr/nl2Sr4YVi4ldy89jLzvp2kER6OfYN72jUJ24Mnbjfy+JK5ayeXOhYopF3b\nSTW/0K7AB0siVgrvI90/Qq7osG4nRgJNltEblYk1bKYx1E7CBxMcfXrljNFKJhr2\nJKKl2oXYTy/81YlUibfkbfGxy7DL6W/EarkyG/YimN0jFdrQqHzZsgkoFW4Iav8u\n5SjkbVfyEYLlrOkdtn25G9K9jAmKDQG2fXap20SBTWQQu4rWLVKExK3rEVnMhTId\nM3wCCcnFjiZVSwy6pGNkpiMfLu37srAoNcHKEmaErbg6tNp2RDMNXcfBxsxNIcp6\nR5ZaESwcriLjgHzAgmMCAwEAAaNQME4wHQYDVR0OBBYEFGmG5dc2YEEDbFA2+SBS\nA13S5l4VMB8GA1UdIwQYMBaAFGmG5dc2YEEDbFA2+SBSA13S5l4VMAwGA1UdEwQF\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAAEmICKB6kq/Y++JKHZg88JS4nlWzIFh\nNBrfyCnMQiL9mmllEXQIhK25xleQwQGsBF2odDj+8H9CG/lwWLmyC5+TryFjWrhn\nHlt8QJb8E4dIZkYAxDL/ii6tXfFTjvrXsTcY2moD6ZoOoxahVOjVfwkHup0ONn2v\nsCL/11FneR0jhgruXKoqrKspgNVuYp+t4IKnnePpeGJb/I3SyS9GUXlScV/uWyRw\nLdIoR2teEWcWeNrMLmth0NSa3AF3gd9+HTaGpESsusG4qPamqiSM7+INAeTo4k8b\nlbqLwo3Ju6cNGGlDSsDXIUahpCdKnqxBALytITmIcHwsR4vYaDP4iOE=\n-----END CERTIFICATE-----", "mock-key-id-2": "-----BEGIN 
CERTIFICATE-----\nMIIDKjCCAhKgAwIBAgIIBIUnv7pTIx8wDQYJKoZIhvcNAQEFBQAwODE2MDQGA1UE\nAxMtdGVzdC00ODQubWctdGVzdC0xMjEwLmlhbS5nc2VydmljZWFjY291bnQuY29t\nMB4XDTE2MDMxOTE3NTE1NFoXDTE2MDMyMTA2NTE1NFowODE2MDQGA1UEAxMtdGVz\ndC00ODQubWctdGVzdC0xMjEwLmlhbS5nc2VydmljZWFjY291bnQuY29tMIIBIjAN\nBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA7beJFmTrA/T4AeMWk/IjxUlGpaxH\n6D1CYbfxEBJUqzuIe7ujaxh76ik/FPQV5WxlL1GOjW0/f5CsmrNaFmTmQbsK4BY3\n3cCd3gM8LcEtmF1I9NxxpXxrZihlfuwbEpb5NpjGPkCC+fG3gTY7qtjuO6e8pGb2\nVQQguOGXKw/YZLZRZXZ41xkQRYrs+tFw48+4YkjMsYJIxyBMiL5Q/HNAQ2IUyZwr\nuc+CMcWyPLNcnsRNXgnPXQD/GKZQnjjJ5KzQAU1vnDcufL9V5KRhb0kRxTTUjE7D\nJl3x4+J6+hbAheZFu9Fntrxie9TvQuQbEBm/437QFYZphfQli0fDjlPHSwIDAQAB\nozgwNjAMBgNVHRMBAf8EAjAAMA4GA1UdDwEB/wQEAwIHgDAWBgNVHSUBAf8EDDAK\nBggrBgEFBQcDAjANBgkqhkiG9w0BAQUFAAOCAQEAQzlUGQiWiHgeBZyUsetuoTiQ\nsxzU7B1qw3la/FQrG+jRFr9GE3yjOOxi9JvX16U/ebwSHLUip8UFf/Ir6AJlt/tt\nIjBA6TOd8DysAtr4PCZrAP/m43H9w4lBWdWl1XJE2YfYQgZnorveAMUZqTo0P0pd\nFo3IsYBSTMflKv2Vqz91PPiHgyu2fk+8TYwJT57rnnkS6VzdORTIf+9ZB+J1ye9i\nQN5IgdZ/eqFiJPD8qT5jOcXelWSWqHHdGrNjQNp+z8jgMusY5/ZAlZUe55eo3I0m\nuDSPImLNkDwqY0+bBW6Fp5xi/4O+gJg3cQ+/PeIHzoFqKAlSpxQZSCziPpGfAA==\n-----END CERTIFICATE-----\n", "mock-key-id-3": "-----BEGIN 
CERTIFICATE-----\nMIIC+jCCAeKgAwIBAgIIRKlYUHIlbRkwDQYJKoZIhvcNAQEFBQAwIDEeMBwGA1UE\nAxMVMTAwNzcyNzQyMjQ5MTUwNTc4MjYyMB4XDTE2MDIxMDAxMzI1N1oXDTI2MDIw\nNzAxMzI1N1owIDEeMBwGA1UEAxMVMTAwNzcyNzQyMjQ5MTUwNTc4MjYyMIIBIjAN\nBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAkIqM1JR3h61LDUN+FaTDFnwQkrVD\nG9qDy54QNSxElxgEPRyGhc4KAz7cRTMSWfoCWYkYaW/nZkqTfWlhsawZQU1oK8pj\nxJhUQHYpTISA4DTLmz5R4ng9mVGMqZWaCs4oiPvdizyrxus6RIdQ5bRbZyhl4pzn\n23E8tszCxnFX4KYneyvLtbcXoEvSezhm6n6yT4bNzSZguKxOSZU3XNFPmcBVjYPN\njA2aWzE54uu0ve+JrBVkRq/3XB/OvvJyjIovdxnzF91YJ5KukHJMhZqnQRIc3VcN\njx5+WwSQKE2klR7wB0AsAIs0lxA4wxH/+EUCHM2S5lfrMidz7cOXmMmeJwIDAQAB\nozgwNjAMBgNVHRMBAf8EAjAAMA4GA1UdDwEB/wQEAwIHgDAWBgNVHSUBAf8EDDAK\nBggrBgEFBQcDAjANBgkqhkiG9w0BAQUFAAOCAQEAF+V8kmeJQnvpPKlFT74BROi0\n1Eple2mSsyQbtm1kL7FJpl1AXZ4sLXXTVj3ql0LsqVawDCVtUSvDXBSHtejnh0bi\nZ0WUyEEJ38XPfXRilIaTrYP408ezowDaXxrfLhho1EjoMOPgXjksu1FyhBFoHmif\ndLJoxyA4f+8DZ8jj7ew6ZIVEmvONYgctpU72uUh36Vyl84oc9D2GByq/zYDXvVvl\nSKWYZ5+86/eGocO4sosB5QrsEdVGT2Im6mz2DUIewSyIvrDgZ5r3XyL4RXpdi8+8\n9re/meIh5pnhimU4pX9weQia8bqSPf0oZhh0uAWxO5ES7k1GwblnJfxeCZ0xDQ==\n-----END CERTIFICATE-----\n" } golang-google-firebase-go-4.18.0/testdata/refresh_token.json000066400000000000000000000002351505612111400240440ustar00rootroot00000000000000{ "type": "authorized_user", "client_id": "mock.apps.googleusercontent.com", "client_secret": "mock-secret", "refresh_token": "mock-refresh-token" } golang-google-firebase-go-4.18.0/testdata/service_account.json000066400000000000000000000042761505612111400243730ustar00rootroot00000000000000{ "type": "service_account", "project_id": "mock-project-id", "private_key_id": "mock-key-id-1", "private_key": "-----BEGIN RSA PRIVATE 
KEY-----\nMIIEpAIBAAKCAQEAwJENcRev+eXZKvhhWLiV3Lz2MvO+naQRHo59g3vaNQnbgyduN/L4krlr\nJ5c6FiikXdtJNb/QrsAHSyJWCu8j3T9CruiwbidGAk2W0RuViTVspjHUTsIHExx9euWM0Uom\nGvYkoqXahdhPL/zViVSJt+Rt8bHLsMvpb8RquTIb9iKY3SMV2tCofNmyCSgVbghq/y7lKORt\nV/IRguWs6R22fbkb0r2MCYoNAbZ9dqnbRIFNZBC7itYtUoTEresRWcyFMh0zfAIJycWOJlVL\nDLqkY2SmIx8u7fuysCg1wcoSZoStuDq02nZEMw1dx8HGzE0hynpHlloRLByuIuOAfMCCYwID\nAQABAoIBADFtihu7TspAO0wSUTpqttzgC/nsIsNn95T2UjVLtyjiDNxPZLUrwq42tdCFur0x\nVW9Z+CK5x6DzXWvltlw8IeKKeF1ZEOBVaFzy+YFXKTz835SROcO1fgdjyrme7lRSShGlmKW/\nGKY+baUNquoDLw5qreXaE0SgMp0jt5ktyYuVxvhLDeV4omw2u6waoGkifsGm8lYivg5l3VR7\nw2IVOvYZTt4BuSYVwOM+qjwaS1vtL7gv0SUjrj85Ja6zERRdFiITDhZw6nsvacr9/+/aut9E\naL/koSSb62g5fntQMEwoT4hRnjPnAedmorM9Rhddh2TB3ZKTBbMN1tUk3fJxOuECgYEA+z6l\neSaAcZ3qvwpntcXSpwwJ0SSmzLTH2RJNf+Ld3eBHiSvLTG53dWB7lJtF4R1KcIwf+KGcOFJv\nsnepzcZBylRvT8RrAAkV0s9OiVm1lXZyaepbLg4GGFJBPi8A6VIAj7zYknToRApdW0s1x/XX\nChewfJDckqsevTMovdbg8YkCgYEAxDYX+3mfvv/opo6HNNY3SfVunM+4vVJL+n8gWZ2w9kz3\nQ9Ub9YbRmI7iQaiVkO5xNuoG1n9bM+3Mnm84aQ1YeNT01YqeyQsipP5Wi+um0PzYTaBw9RO+\n8Gh6992OwlJiRtFk5WjalNWOxY4MU0ImnJwIfKQlUODvLmcixm68NYsCgYEAuAqI3jkk55Vd\nKvotREsX5wP7gPePM+7NYiZ1HNQL4Ab1f/bTojZdTV8Sx6YCR0fUiqMqnE+OBvfkGGBtw22S\nLesx6sWf99Ov58+x4Q0U5dpxL0Lb7d2Z+2Dtp+Z4jXFjNeeI4ae/qG/LOR/b0pE0J5F415ap\n7Mpq5v89vepUtrkCgYAjMXytu4v+q1Ikhc4UmRPDrUUQ1WVSd+9u19yKlnFGTFnRjej86hiw\nH3jPxBhHra0a53EgiilmsBGSnWpl1WH4EmJz5vBCKUAmjgQiBrueIqv9iHiaTNdjsanUyaWw\njyxXfXl2eI80QPXh02+8g1H/pzESgjK7Rg1AqnkfVH9nrwKBgQDJVxKBPTw9pigYMVt9iHrR\niCl9zQVjRMbWiPOc0J56+/5FZYm/AOGl9rfhQ9vGxXZYZiOP5FsNkwt05Y1UoAAH4B4VQwbL\nqod71qOcI0ywgZiIR87CYw40gzRfjWnN+YEEW1qfyoNLilEwJB8iB/T+ZePHGmJ4MmQ/cTn9\nxpdLXA==\n-----END RSA PRIVATE KEY-----", "client_email": "mock-email@mock-project.iam.gserviceaccount.com", "client_id": "1234567890", "auth_uri": "https://accounts.google.com/o/oauth2/auth", "token_uri": "https://accounts.google.com/o/oauth2/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", "client_x509_cert_url": 
"https://www.googleapis.com/robot/v1/metadata/x509/mock-project-id.iam.gserviceaccount.com" }