coyotte508 HF staff commited on
Commit
767afa7
Β·
unverified Β·
1 Parent(s): d9a878f

πŸ› Fix "signin with HF" within space + CSRF (#236)

Browse files
.env CHANGED
@@ -46,7 +46,8 @@ MODELS=`[
46
  ]`
47
  OLD_MODELS=`[]`# any removed models, `{ name: string, displayName?: string, id?: string }`
48
 
49
- PUBLIC_ORIGIN=#https://hf.co
 
50
  PUBLIC_GOOGLE_ANALYTICS_ID=#G-XXXXXXXX / Leave empty to disable
51
  PUBLIC_DEPRECATED_GOOGLE_ANALYTICS_ID=#UA-XXXXXXXX-X / Leave empty to disable
52
  PUBLIC_ANNOUNCEMENT_BANNERS=`[
 
46
  ]`
47
  OLD_MODELS=`[]`# any removed models, `{ name: string, displayName?: string, id?: string }`
48
 
49
+ PUBLIC_ORIGIN=#https://huggingface.co
50
+ PUBLIC_SHARE_PREFIX=#https://hf.co/chat
51
  PUBLIC_GOOGLE_ANALYTICS_ID=#G-XXXXXXXX / Leave empty to disable
52
  PUBLIC_DEPRECATED_GOOGLE_ANALYTICS_ID=#UA-XXXXXXXX-X / Leave empty to disable
53
  PUBLIC_ANNOUNCEMENT_BANNERS=`[
package-lock.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
  "name": "chat-ui",
3
- "version": "0.1.0",
4
  "lockfileVersion": 3,
5
  "requires": true,
6
  "packages": {
7
  "": {
8
  "name": "chat-ui",
9
- "version": "0.1.0",
10
  "dependencies": {
11
  "@huggingface/hub": "^0.5.1",
12
  "@huggingface/inference": "^2.2.0",
 
1
  {
2
  "name": "chat-ui",
3
+ "version": "0.2.0",
4
  "lockfileVersion": 3,
5
  "requires": true,
6
  "packages": {
7
  "": {
8
  "name": "chat-ui",
9
+ "version": "0.2.0",
10
  "dependencies": {
11
  "@huggingface/hub": "^0.5.1",
12
  "@huggingface/inference": "^2.2.0",
src/hooks.server.ts CHANGED
@@ -3,6 +3,7 @@ import type { Handle } from "@sveltejs/kit";
3
  import {
4
  PUBLIC_GOOGLE_ANALYTICS_ID,
5
  PUBLIC_DEPRECATED_GOOGLE_ANALYTICS_ID,
 
6
  } from "$env/static/public";
7
  import { collections } from "$lib/server/database";
8
  import { base } from "$app/paths";
@@ -20,25 +21,50 @@ export const handle: Handle = async ({ event, resolve }) => {
20
  event.locals.user = user;
21
  }
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  if (
24
  !event.url.pathname.startsWith(`${base}/login`) &&
25
  !event.url.pathname.startsWith(`${base}/admin`) &&
26
  !["GET", "OPTIONS", "HEAD"].includes(event.request.method)
27
  ) {
28
- const sendJson =
29
- event.request.headers.get("accept")?.includes("application/json") ||
30
- event.request.headers.get("content-type")?.includes("application/json");
31
-
32
  if (!user && requiresUser) {
33
- return new Response(
34
- sendJson ? JSON.stringify({ error: ERROR_MESSAGES.authOnly }) : ERROR_MESSAGES.authOnly,
35
- {
36
- status: 401,
37
- headers: {
38
- "content-type": sendJson ? "application/json" : "text/plain",
39
- },
40
- }
41
- );
42
  }
43
 
44
  // if login is not required and the call is not from /settings, we check if the user has accepted the ethics modal first.
@@ -50,17 +76,7 @@ export const handle: Handle = async ({ event, resolve }) => {
50
  });
51
 
52
  if (!hasAcceptedEthicsModal) {
53
- return new Response(
54
- sendJson
55
- ? JSON.stringify({ error: "You need to accept the welcome modal first" })
56
- : "You need to accept the welcome modal first",
57
- {
58
- status: 405,
59
- headers: {
60
- "content-type": sendJson ? "application/json" : "text/plain",
61
- },
62
- }
63
- );
64
  }
65
  }
66
  }
 
3
  import {
4
  PUBLIC_GOOGLE_ANALYTICS_ID,
5
  PUBLIC_DEPRECATED_GOOGLE_ANALYTICS_ID,
6
+ PUBLIC_ORIGIN,
7
  } from "$env/static/public";
8
  import { collections } from "$lib/server/database";
9
  import { base } from "$app/paths";
 
21
  event.locals.user = user;
22
  }
23
 
24
+ function errorResponse(status: number, message: string) {
25
+ const sendJson =
26
+ event.request.headers.get("accept")?.includes("application/json") ||
27
+ event.request.headers.get("content-type")?.includes("application/json");
28
+ return new Response(sendJson ? JSON.stringify({ error: message }) : message, {
29
+ status,
30
+ headers: {
31
+ "content-type": sendJson ? "application/json" : "text/plain",
32
+ },
33
+ });
34
+ }
35
+
36
+ // CSRF protection
37
+ const requestContentType = event.request.headers.get("content-type")?.split(";")[0] ?? "";
38
+ /** https://developer.mozilla.org/en-US/docs/Web/HTML/Element/form#attr-enctype */
39
+ const nativeFormContentTypes = [
40
+ "multipart/form-data",
41
+ "application/x-www-form-urlencoded",
42
+ "text/plain",
43
+ ];
44
+ if (event.request.method === "POST" && nativeFormContentTypes.includes(requestContentType)) {
45
+ const referer = event.request.headers.get("referer");
46
+
47
+ if (!referer) {
48
+ return errorResponse(403, "Non-JSON form requests need to have a referer");
49
+ }
50
+
51
+ const validOrigins = [
52
+ new URL(event.request.url).origin,
53
+ ...(PUBLIC_ORIGIN ? [new URL(PUBLIC_ORIGIN).origin] : []),
54
+ ];
55
+
56
+ if (!validOrigins.includes(new URL(referer).origin)) {
57
+ return errorResponse(403, "Invalid referer for POST request");
58
+ }
59
+ }
60
+
61
  if (
62
  !event.url.pathname.startsWith(`${base}/login`) &&
63
  !event.url.pathname.startsWith(`${base}/admin`) &&
64
  !["GET", "OPTIONS", "HEAD"].includes(event.request.method)
65
  ) {
 
 
 
 
66
  if (!user && requiresUser) {
67
+ return errorResponse(401, ERROR_MESSAGES.authOnly);
 
 
 
 
 
 
 
 
68
  }
69
 
70
  // if login is not required and the call is not from /settings, we check if the user has accepted the ethics modal first.
 
76
  });
77
 
78
  if (!hasAcceptedEthicsModal) {
79
+ return errorResponse(405, "You need to accept the welcome modal first");
 
 
 
 
 
 
 
 
 
 
80
  }
81
  }
82
  }
src/lib/components/LoginModal.svelte CHANGED
@@ -1,5 +1,5 @@
1
  <script lang="ts">
2
- import { enhance } from "$app/forms";
3
  import { base } from "$app/paths";
4
  import { page } from "$app/stores";
5
  import { PUBLIC_VERSION } from "$env/static/public";
@@ -9,6 +9,8 @@
9
  import type { LayoutData } from "../../routes/$types";
10
 
11
  export let settings: LayoutData["settings"];
 
 
12
  </script>
13
 
14
  <Modal>
@@ -35,7 +37,7 @@
35
  </p>
36
  <form
37
  action="{base}/{$page.data.requiresLogin ? 'login' : 'settings'}"
38
- use:enhance
39
  method="POST"
40
  >
41
  {#if $page.data.requiresLogin}
 
1
  <script lang="ts">
2
+ import { browser } from "$app/environment";
3
  import { base } from "$app/paths";
4
  import { page } from "$app/stores";
5
  import { PUBLIC_VERSION } from "$env/static/public";
 
9
  import type { LayoutData } from "../../routes/$types";
10
 
11
  export let settings: LayoutData["settings"];
12
+
13
+ const isIframe = browser && window.self !== window.parent;
14
  </script>
15
 
16
  <Modal>
 
37
  </p>
38
  <form
39
  action="{base}/{$page.data.requiresLogin ? 'login' : 'settings'}"
40
+ target={isIframe ? "_blank" : ""}
41
  method="POST"
42
  >
43
  {#if $page.data.requiresLogin}
src/lib/components/MobileNav.svelte CHANGED
@@ -40,7 +40,7 @@
40
  bind:this={openEl}><CarbonTextAlignJustify /></button
41
  >
42
  <span class="truncate px-4">{title}</span>
43
- <a href={base || "/"} class="-mr-3 flex h-9 w-9 shrink-0 items-center justify-center"
44
  ><CarbonAdd /></a
45
  >
46
  </nav>
 
40
  bind:this={openEl}><CarbonTextAlignJustify /></button
41
  >
42
  <span class="truncate px-4">{title}</span>
43
+ <a href={`${base}/`} class="-mr-3 flex h-9 w-9 shrink-0 items-center justify-center"
44
  ><CarbonAdd /></a
45
  >
46
  </nav>
src/lib/components/NavMenu.svelte CHANGED
@@ -26,7 +26,7 @@
26
  HuggingChat
27
  </a>
28
  <a
29
- href={base || "/"}
30
  class="flex rounded-lg border bg-white px-2 py-0.5 text-center shadow-sm hover:shadow-none dark:border-gray-600 dark:bg-gray-700"
31
  >
32
  New Chat
 
26
  HuggingChat
27
  </a>
28
  <a
29
+ href={`${base}/`}
30
  class="flex rounded-lg border bg-white px-2 py-0.5 text-center shadow-sm hover:shadow-none dark:border-gray-600 dark:bg-gray-700"
31
  >
32
  New Chat
src/lib/server/auth.ts CHANGED
@@ -1,15 +1,13 @@
1
  import { Issuer, BaseClient, type UserinfoResponse, TokenSet } from "openid-client";
2
- import { addDays, addYears } from "date-fns";
3
  import {
4
  COOKIE_NAME,
5
  OPENID_CLIENT_ID,
6
  OPENID_CLIENT_SECRET,
7
  OPENID_PROVIDER_URL,
8
  } from "$env/static/private";
9
- import { PUBLIC_ORIGIN } from "$env/static/public";
10
  import { sha256 } from "$lib/utils/sha256";
11
  import { z } from "zod";
12
- import { base } from "$app/paths";
13
  import { dev } from "$app/environment";
14
  import type { Cookies } from "@sveltejs/kit";
15
 
@@ -35,8 +33,6 @@ export function refreshSessionCookie(cookies: Cookies, sessionId: string) {
35
  });
36
  }
37
 
38
- export const getRedirectURI = (url: URL) => `${PUBLIC_ORIGIN || url.origin}${base}/login/callback`;
39
-
40
  export const OIDC_SCOPES = "openid profile";
41
 
42
  export const authCondition = (locals: App.Locals) => {
@@ -48,8 +44,11 @@ export const authCondition = (locals: App.Locals) => {
48
  /**
49
  * Generates a CSRF token using the user sessionId. Note that we don't need a secret because sessionId is enough.
50
  */
51
- export async function generateCsrfToken(sessionId: string): Promise<string> {
52
- const data = { expiration: addDays(new Date(), 1).getTime() };
 
 
 
53
 
54
  return Buffer.from(
55
  JSON.stringify({
@@ -74,7 +73,7 @@ export async function getOIDCAuthorizationUrl(
74
  params: { sessionId: string }
75
  ): Promise<string> {
76
  const client = await getOIDCClient(settings);
77
- const csrfToken = await generateCsrfToken(params.sessionId);
78
  const url = client.authorizationUrl({
79
  scope: OIDC_SCOPES,
80
  state: csrfToken,
@@ -91,21 +90,30 @@ export async function getOIDCUserData(settings: OIDCSettings, code: string): Pro
91
  return { token, userData };
92
  }
93
 
94
- export async function validateCsrfToken(token: string, sessionId: string) {
 
 
 
 
 
 
95
  try {
96
  const { data, signature } = z
97
  .object({
98
  data: z.object({
99
  expiration: z.number().int(),
 
100
  }),
101
  signature: z.string().length(64),
102
  })
103
  .parse(JSON.parse(token));
104
  const reconstructSign = await sha256(JSON.stringify(data) + "##" + sessionId);
105
 
106
- return data.expiration > Date.now() && signature === reconstructSign;
 
 
107
  } catch (e) {
108
  console.error(e);
109
- return false;
110
  }
 
111
  }
 
1
  import { Issuer, BaseClient, type UserinfoResponse, TokenSet } from "openid-client";
2
+ import { addHours, addYears } from "date-fns";
3
  import {
4
  COOKIE_NAME,
5
  OPENID_CLIENT_ID,
6
  OPENID_CLIENT_SECRET,
7
  OPENID_PROVIDER_URL,
8
  } from "$env/static/private";
 
9
  import { sha256 } from "$lib/utils/sha256";
10
  import { z } from "zod";
 
11
  import { dev } from "$app/environment";
12
  import type { Cookies } from "@sveltejs/kit";
13
 
 
33
  });
34
  }
35
 
 
 
36
  export const OIDC_SCOPES = "openid profile";
37
 
38
  export const authCondition = (locals: App.Locals) => {
 
44
  /**
45
  * Generates a CSRF token using the user sessionId. Note that we don't need a secret because sessionId is enough.
46
  */
47
+ export async function generateCsrfToken(sessionId: string, redirectUrl: string): Promise<string> {
48
+ const data = {
49
+ expiration: addHours(new Date(), 1).getTime(),
50
+ redirectUrl,
51
+ };
52
 
53
  return Buffer.from(
54
  JSON.stringify({
 
73
  params: { sessionId: string }
74
  ): Promise<string> {
75
  const client = await getOIDCClient(settings);
76
+ const csrfToken = await generateCsrfToken(params.sessionId, settings.redirectURI);
77
  const url = client.authorizationUrl({
78
  scope: OIDC_SCOPES,
79
  state: csrfToken,
 
90
  return { token, userData };
91
  }
92
 
93
+ export async function validateAndParseCsrfToken(
94
+ token: string,
95
+ sessionId: string
96
+ ): Promise<{
97
+ /** This is the redirect url that was passed to the OIDC provider */
98
+ redirectUrl: string;
99
+ } | null> {
100
  try {
101
  const { data, signature } = z
102
  .object({
103
  data: z.object({
104
  expiration: z.number().int(),
105
+ redirectUrl: z.string().url(),
106
  }),
107
  signature: z.string().length(64),
108
  })
109
  .parse(JSON.parse(token));
110
  const reconstructSign = await sha256(JSON.stringify(data) + "##" + sessionId);
111
 
112
+ if (data.expiration > Date.now() && signature === reconstructSign) {
113
+ return { redirectUrl: data.redirectUrl };
114
+ }
115
  } catch (e) {
116
  console.error(e);
 
117
  }
118
+ return null;
119
  }
src/routes/+layout.svelte CHANGED
@@ -56,7 +56,7 @@
56
  if ($page.params.id !== id) {
57
  await invalidate(UrlDependency.ConversationList);
58
  } else {
59
- await goto(base || "/", { invalidateAll: true });
60
  }
61
  } catch (err) {
62
  console.error(err);
 
56
  if ($page.params.id !== id) {
57
  await invalidate(UrlDependency.ConversationList);
58
  } else {
59
+ await goto(`${base}/`, { invalidateAll: true });
60
  }
61
  } catch (err) {
62
  console.error(err);
src/routes/conversation/+server.ts CHANGED
@@ -57,5 +57,5 @@ export const POST: RequestHandler = async ({ locals, request }) => {
57
  };
58
 
59
  export const GET: RequestHandler = async () => {
60
- throw redirect(302, base || "/");
61
  };
 
57
  };
58
 
59
  export const GET: RequestHandler = async () => {
60
+ throw redirect(302, `${base}/`);
61
  };
src/routes/conversation/[id]/share/+server.ts CHANGED
@@ -1,5 +1,5 @@
1
  import { base } from "$app/paths";
2
- import { PUBLIC_ORIGIN } from "$env/static/public";
3
  import { authCondition } from "$lib/server/auth";
4
  import { collections } from "$lib/server/database";
5
  import type { SharedConversation } from "$lib/types/SharedConversation";
@@ -52,5 +52,5 @@ export async function POST({ params, url, locals }) {
52
  }
53
 
54
  function getShareUrl(url: URL, shareId: string): string {
55
- return `${PUBLIC_ORIGIN || url.origin}${base}/r/${shareId}`;
56
  }
 
1
  import { base } from "$app/paths";
2
+ import { PUBLIC_ORIGIN, PUBLIC_SHARE_PREFIX } from "$env/static/public";
3
  import { authCondition } from "$lib/server/auth";
4
  import { collections } from "$lib/server/database";
5
  import type { SharedConversation } from "$lib/types/SharedConversation";
 
52
  }
53
 
54
  function getShareUrl(url: URL, shareId: string): string {
55
+ return `${PUBLIC_SHARE_PREFIX || `${PUBLIC_ORIGIN || url.origin}${base}`}/r/${shareId}`;
56
  }
src/routes/login/+page.server.ts CHANGED
@@ -1,11 +1,13 @@
1
  import { redirect } from "@sveltejs/kit";
2
- import { getOIDCAuthorizationUrl, getRedirectURI } from "$lib/server/auth";
 
3
 
4
  export const actions = {
5
- default: async function ({ url, locals }) {
6
  // TODO: Handle errors if provider is not responding
 
7
  const authorizationUrl = await getOIDCAuthorizationUrl(
8
- { redirectURI: getRedirectURI(url) },
9
  { sessionId: locals.sessionId }
10
  );
11
 
 
1
  import { redirect } from "@sveltejs/kit";
2
+ import { getOIDCAuthorizationUrl } from "$lib/server/auth";
3
+ import { base } from "$app/paths";
4
 
5
  export const actions = {
6
+ default: async function ({ url, locals, request }) {
7
  // TODO: Handle errors if provider is not responding
8
+ const referer = request.headers.get("referer");
9
  const authorizationUrl = await getOIDCAuthorizationUrl(
10
+ { redirectURI: `${(referer ? new URL(referer) : url).origin}${base}/login/callback` },
11
  { sessionId: locals.sessionId }
12
  );
13
 
src/routes/login/callback/+server.ts CHANGED
@@ -1,5 +1,5 @@
1
  import { redirect, error } from "@sveltejs/kit";
2
- import { getOIDCUserData, getRedirectURI, validateCsrfToken } from "$lib/server/auth";
3
  import { z } from "zod";
4
  import { base } from "$app/paths";
5
  import { updateUser } from "./updateUser";
@@ -13,7 +13,7 @@ export async function GET({ url, locals, cookies }) {
13
 
14
  if (errorName) {
15
  // TODO: Display denied error on the UI
16
- throw redirect(302, base || "/");
17
  }
18
 
19
  const { code, state } = z
@@ -25,15 +25,15 @@ export async function GET({ url, locals, cookies }) {
25
 
26
  const csrfToken = Buffer.from(state, "base64").toString("utf-8");
27
 
28
- const isValidToken = await validateCsrfToken(csrfToken, locals.sessionId);
29
 
30
- if (!isValidToken) {
31
  throw error(403, "Invalid or expired CSRF token");
32
  }
33
 
34
- const { userData } = await getOIDCUserData({ redirectURI: getRedirectURI(url) }, code);
35
 
36
  await updateUser({ userData, locals, cookies });
37
 
38
- throw redirect(302, base || "/");
39
  }
 
1
  import { redirect, error } from "@sveltejs/kit";
2
+ import { getOIDCUserData, validateAndParseCsrfToken } from "$lib/server/auth";
3
  import { z } from "zod";
4
  import { base } from "$app/paths";
5
  import { updateUser } from "./updateUser";
 
13
 
14
  if (errorName) {
15
  // TODO: Display denied error on the UI
16
+ throw redirect(302, `${base}/`);
17
  }
18
 
19
  const { code, state } = z
 
25
 
26
  const csrfToken = Buffer.from(state, "base64").toString("utf-8");
27
 
28
+ const validatedToken = await validateAndParseCsrfToken(csrfToken, locals.sessionId);
29
 
30
+ if (!validatedToken) {
31
  throw error(403, "Invalid or expired CSRF token");
32
  }
33
 
34
+ const { userData } = await getOIDCUserData({ redirectURI: validatedToken.redirectUrl }, code);
35
 
36
  await updateUser({ userData, locals, cookies });
37
 
38
+ throw redirect(302, `${base}/`);
39
  }
src/routes/logout/+page.server.ts CHANGED
@@ -12,6 +12,6 @@ export const actions = {
12
  secure: !dev,
13
  httpOnly: true,
14
  });
15
- throw redirect(303, base || "/");
16
  },
17
  };
 
12
  secure: !dev,
13
  httpOnly: true,
14
  });
15
+ throw redirect(303, `${base}/`);
16
  },
17
  };
src/routes/settings/+page.server.ts CHANGED
@@ -41,6 +41,6 @@ export const actions = {
41
  }
42
  );
43
 
44
- throw redirect(303, request.headers.get("referer") || base || "/");
45
  },
46
  };
 
41
  }
42
  );
43
 
44
+ throw redirect(303, request.headers.get("referer") || `${base}/`);
45
  },
46
  };
svelte.config.js CHANGED
@@ -21,7 +21,7 @@ const config = {
21
  base: process.env.APP_BASE || "",
22
  },
23
  csrf: {
24
- // todo: fix
25
  checkOrigin: false,
26
  },
27
  },
 
21
  base: process.env.APP_BASE || "",
22
  },
23
  csrf: {
24
+ // handled in hooks.server.ts, because we can have multiple valid origins
25
  checkOrigin: false,
26
  },
27
  },