enzostvs HF staff commited on
Commit
c2c7576
·
1 Parent(s): df0db78

Add providers selector

Browse files
public/providers/fireworks-ai.svg ADDED
public/providers/hyperbolic.svg ADDED
public/providers/nebius.svg ADDED
public/providers/sambanova.svg ADDED
server.js CHANGED
@@ -8,6 +8,7 @@ import { InferenceClient } from "@huggingface/inference";
8
  import bodyParser from "body-parser";
9
 
10
  import checkUser from "./middlewares/checkUser.js";
 
11
 
12
  // Load environment variables from .env file
13
  dotenv.config();
@@ -174,7 +175,7 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
174
  });
175
 
176
  app.post("/api/ask-ai", async (req, res) => {
177
- const { prompt, html, previousPrompt } = req.body;
178
  if (!prompt) {
179
  return res.status(400).send({
180
  ok: false,
@@ -192,7 +193,6 @@ app.post("/api/ask-ai", async (req, res) => {
192
  "0.0.0.0";
193
 
194
  if (!hf_token) {
195
- // Rate limit requests from the same IP address, to prevent abuse, free is limited to 2 requests per IP
196
  ipAddresses.set(ip, (ipAddresses.get(ip) || 0) + 1);
197
  if (ipAddresses.get(ip) > MAX_REQUESTS_PER_IP) {
198
  return res.status(429).send({
@@ -213,10 +213,26 @@ app.post("/api/ask-ai", async (req, res) => {
213
  const client = new InferenceClient(token);
214
  let completeResponse = "";
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  try {
217
  const chatCompletion = client.chatCompletionStream({
218
  model: MODEL_ID,
219
- provider: "hyperbolic",
220
  messages: [
221
  {
222
  role: "system",
@@ -244,7 +260,7 @@ app.post("/api/ask-ai", async (req, res) => {
244
  content: prompt,
245
  },
246
  ],
247
- max_tokens: 12_000,
248
  });
249
 
250
  while (true) {
@@ -254,25 +270,37 @@ app.post("/api/ask-ai", async (req, res) => {
254
  }
255
  const chunk = value.choices[0]?.delta?.content;
256
  if (chunk) {
257
- res.write(chunk);
258
- completeResponse += chunk;
259
-
260
- // Break when HTML is complete
261
- if (completeResponse.includes("</html>")) {
262
- break;
 
 
 
 
 
 
 
 
 
 
 
 
263
  }
264
  }
265
  }
266
-
267
  // End the response stream
 
268
  res.end();
269
  } catch (error) {
270
- console.error("Error:", error);
271
- // If we haven't sent a response yet, send an error
272
  if (!res.headersSent) {
273
  res.status(500).send({
274
  ok: false,
275
- message: `You probably reached the MAX_TOKENS limit, context is too long. You can start a new conversation by refreshing the page.`,
 
 
276
  });
277
  } else {
278
  // Otherwise end the stream
 
8
  import bodyParser from "body-parser";
9
 
10
  import checkUser from "./middlewares/checkUser.js";
11
+ import { PROVIDERS } from "./utils/providers.js";
12
 
13
  // Load environment variables from .env file
14
  dotenv.config();
 
175
  });
176
 
177
  app.post("/api/ask-ai", async (req, res) => {
178
+ const { prompt, html, previousPrompt, provider } = req.body;
179
  if (!prompt) {
180
  return res.status(400).send({
181
  ok: false,
 
193
  "0.0.0.0";
194
 
195
  if (!hf_token) {
 
196
  ipAddresses.set(ip, (ipAddresses.get(ip) || 0) + 1);
197
  if (ipAddresses.get(ip) > MAX_REQUESTS_PER_IP) {
198
  return res.status(429).send({
 
213
  const client = new InferenceClient(token);
214
  let completeResponse = "";
215
 
216
+ const selectedProvider =
217
+ PROVIDERS.find((providerItem) => providerItem.id === provider) ??
218
+ PROVIDERS[0];
219
+
220
+ let TOKENS_USED = prompt?.length;
221
+ if (previousPrompt) TOKENS_USED += previousPrompt.length;
222
+ if (html) TOKENS_USED += html.length;
223
+
224
+ if (TOKENS_USED >= selectedProvider.max_tokens) {
225
+ return res.status(400).send({
226
+ ok: false,
227
+ openSelectProvider: true,
228
+ message: `Context is too long. ${selectedProvider.name} allow ${selectedProvider.max_tokens} max tokens.`,
229
+ });
230
+ }
231
+
232
  try {
233
  const chatCompletion = client.chatCompletionStream({
234
  model: MODEL_ID,
235
+ provider: selectedProvider.id,
236
  messages: [
237
  {
238
  role: "system",
 
260
  content: prompt,
261
  },
262
  ],
263
+ max_tokens: selectedProvider.max_tokens,
264
  });
265
 
266
  while (true) {
 
270
  }
271
  const chunk = value.choices[0]?.delta?.content;
272
  if (chunk) {
273
+ if (provider !== "sambanova") {
274
+ res.write(chunk);
275
+ completeResponse += chunk;
276
+
277
+ if (completeResponse.includes("</html>")) {
278
+ break;
279
+ }
280
+ } else {
281
+ let newChunk = chunk;
282
+ if (chunk.includes("</html>")) {
283
+ // Replace everything after the last </html> tag with an empty string
284
+ newChunk = newChunk.replace(/<\/html>[\s\S]*/, "</html>");
285
+ }
286
+ completeResponse += newChunk;
287
+ res.write(newChunk);
288
+ if (newChunk.includes("</html>")) {
289
+ break;
290
+ }
291
  }
292
  }
293
  }
 
294
  // End the response stream
295
+ // return the total_tokens used to the client
296
  res.end();
297
  } catch (error) {
 
 
298
  if (!res.headersSent) {
299
  res.status(500).send({
300
  ok: false,
301
+ // use generic error,
302
+ message:
303
+ "An error occurred while processing your request. Please try again or switch provider.",
304
  });
305
  } else {
306
  // Otherwise end the stream
src/components/App.tsx CHANGED
@@ -7,10 +7,10 @@ import { toast } from "react-toastify";
7
 
8
  import Header from "./header/header";
9
  import DeployButton from "./deploy-button/deploy-button";
10
- import { defaultHTML } from "../utils/consts";
11
  import Tabs from "./tabs/tabs";
12
  import AskAI from "./ask-ai/ask-ai";
13
- import { Auth } from "../utils/types";
14
  import Preview from "./preview/preview";
15
 
16
  function App() {
 
7
 
8
  import Header from "./header/header";
9
  import DeployButton from "./deploy-button/deploy-button";
10
+ import { defaultHTML } from "./../../utils/consts";
11
  import Tabs from "./tabs/tabs";
12
  import AskAI from "./ask-ai/ask-ai";
13
+ import { Auth } from "./../../utils/types";
14
  import Preview from "./preview/preview";
15
 
16
  function App() {
src/components/ask-ai/ask-ai.tsx CHANGED
@@ -5,8 +5,10 @@ import classNames from "classnames";
5
  import { toast } from "react-toastify";
6
 
7
  import Login from "../login/login";
8
- import { defaultHTML } from "../../utils/consts";
9
  import SuccessSound from "./../../assets/success.mp3";
 
 
10
 
11
  function AskAI({
12
  html,
@@ -25,12 +27,17 @@ function AskAI({
25
  const [prompt, setPrompt] = useState("");
26
  const [hasAsked, setHasAsked] = useState(false);
27
  const [previousPrompt, setPreviousPrompt] = useState("");
 
 
 
 
28
  const audio = new Audio(SuccessSound);
29
  audio.volume = 0.5;
30
 
31
  const callAi = async () => {
32
  if (isAiWorking || !prompt.trim()) return;
33
  setisAiWorking(true);
 
34
 
35
  let contentResponse = "";
36
  let lastRenderTime = 0;
@@ -39,6 +46,7 @@ function AskAI({
39
  method: "POST",
40
  body: JSON.stringify({
41
  prompt,
 
42
  ...(html === defaultHTML ? {} : { html }),
43
  ...(previousPrompt ? { previousPrompt } : {}),
44
  }),
@@ -51,8 +59,10 @@ function AskAI({
51
  const res = await request.json();
52
  if (res.openLogin) {
53
  setOpen(true);
 
 
 
54
  } else {
55
- // don't show toast if it's a login error
56
  toast.error(res.message);
57
  }
58
  setisAiWorking(false);
@@ -130,7 +140,7 @@ function AskAI({
130
  <input
131
  type="text"
132
  disabled={isAiWorking}
133
- className="w-full bg-transparent max-lg:text-sm outline-none pl-3 text-white placeholder:text-gray-500 font-code"
134
  placeholder={
135
  hasAsked ? "What do you want to ask AI next?" : "Ask AI anything..."
136
  }
@@ -142,13 +152,22 @@ function AskAI({
142
  }
143
  }}
144
  />
145
- <button
146
- disabled={isAiWorking}
147
- className="relative overflow-hidden cursor-pointer flex-none flex items-center justify-center rounded-full text-sm font-semibold size-8 text-center bg-pink-500 hover:bg-pink-400 text-white shadow-sm dark:shadow-highlight/20 disabled:bg-gray-300 disabled:text-gray-500 disabled:cursor-not-allowed disabled:hover:bg-gray-300"
148
- onClick={callAi}
149
- >
150
- <GrSend className="-translate-x-[1px]" />
151
- </button>
 
 
 
 
 
 
 
 
 
152
  </div>
153
  <div
154
  className={classNames(
 
5
  import { toast } from "react-toastify";
6
 
7
  import Login from "../login/login";
8
+ import { defaultHTML } from "./../../../utils/consts";
9
  import SuccessSound from "./../../assets/success.mp3";
10
+ import Settings from "../settings/settings";
11
+ import { useLocalStorage } from "react-use";
12
 
13
  function AskAI({
14
  html,
 
27
  const [prompt, setPrompt] = useState("");
28
  const [hasAsked, setHasAsked] = useState(false);
29
  const [previousPrompt, setPreviousPrompt] = useState("");
30
+ const [provider, setProvider] = useLocalStorage("provider", "fireworks-ai");
31
+ const [openProvider, setOpenProvider] = useState(false);
32
+ const [providerError, setProviderError] = useState("");
33
+
34
  const audio = new Audio(SuccessSound);
35
  audio.volume = 0.5;
36
 
37
  const callAi = async () => {
38
  if (isAiWorking || !prompt.trim()) return;
39
  setisAiWorking(true);
40
+ setProviderError("");
41
 
42
  let contentResponse = "";
43
  let lastRenderTime = 0;
 
46
  method: "POST",
47
  body: JSON.stringify({
48
  prompt,
49
+ provider,
50
  ...(html === defaultHTML ? {} : { html }),
51
  ...(previousPrompt ? { previousPrompt } : {}),
52
  }),
 
59
  const res = await request.json();
60
  if (res.openLogin) {
61
  setOpen(true);
62
+ } else if (res.openSelectProvider) {
63
+ setOpenProvider(true);
64
+ setProviderError(res.message);
65
  } else {
 
66
  toast.error(res.message);
67
  }
68
  setisAiWorking(false);
 
140
  <input
141
  type="text"
142
  disabled={isAiWorking}
143
+ className="w-full bg-transparent max-lg:text-sm outline-none px-3 text-white placeholder:text-gray-500 font-code"
144
  placeholder={
145
  hasAsked ? "What do you want to ask AI next?" : "Ask AI anything..."
146
  }
 
152
  }
153
  }}
154
  />
155
+ <div className="flex items-center justify-end gap-2">
156
+ <Settings
157
+ provider={provider as string}
158
+ onChange={setProvider}
159
+ open={openProvider}
160
+ error={providerError}
161
+ onClose={setOpenProvider}
162
+ />
163
+ <button
164
+ disabled={isAiWorking}
165
+ className="relative overflow-hidden cursor-pointer flex-none flex items-center justify-center rounded-full text-sm font-semibold size-8 text-center bg-pink-500 hover:bg-pink-400 text-white shadow-sm dark:shadow-highlight/20 disabled:bg-gray-300 disabled:text-gray-500 disabled:cursor-not-allowed disabled:hover:bg-gray-300"
166
+ onClick={callAi}
167
+ >
168
+ <GrSend className="-translate-x-[1px]" />
169
+ </button>
170
+ </div>
171
  </div>
172
  <div
173
  className={classNames(
src/components/deploy-button/deploy-button.tsx CHANGED
@@ -6,7 +6,7 @@ import { toast } from "react-toastify";
6
  import SpaceIcon from "@/assets/space.svg";
7
  import Loading from "../loading/loading";
8
  import Login from "../login/login";
9
- import { Auth } from "../../utils/types";
10
 
11
  const MsgToast = ({ url }: { url: string }) => (
12
  <div className="w-full flex items-center justify-center gap-3">
 
6
  import SpaceIcon from "@/assets/space.svg";
7
  import Loading from "../loading/loading";
8
  import Login from "../login/login";
9
+ import { Auth } from "./../../../utils/types";
10
 
11
  const MsgToast = ({ url }: { url: string }) => (
12
  <div className="w-full flex items-center justify-center gap-3">
src/components/login/login.tsx CHANGED
@@ -1,5 +1,5 @@
1
  import { useLocalStorage } from "react-use";
2
- import { defaultHTML } from "../../utils/consts";
3
 
4
  function Login({
5
  html,
 
1
  import { useLocalStorage } from "react-use";
2
+ import { defaultHTML } from "./../../../utils/consts";
3
 
4
  function Login({
5
  html,
src/components/settings/settings.tsx CHANGED
@@ -1,17 +1,32 @@
1
- import { useState } from "react";
2
  import classNames from "classnames";
3
- import Login from "../login/login";
4
 
5
- function Settings() {
6
- const [open, setOpen] = useState(false);
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  return (
9
- <div className="relative">
10
  <button
11
- className="bg-gray-800/70 rounded-md text-xs text-gray-300 hover:brightness-125 px-3 py-1.5 font-medium cursor-pointer"
12
- onClick={() => setOpen(!open)}
 
 
13
  >
14
- Settings
15
  </button>
16
  <div
17
  className={classNames(
@@ -20,17 +35,73 @@ function Settings() {
20
  "opacity-0 pointer-events-none": !open,
21
  }
22
  )}
23
- onClick={() => setOpen(false)}
24
  ></div>
25
  <div
26
  className={classNames(
27
- "absolute top-[calc(100%+8px)] right-0 z-10 w-80 bg-white border border-gray-200 rounded-lg shadow-lg transition-all duration-75 overflow-hidden",
28
  {
29
  "opacity-0 pointer-events-none": !open,
30
  }
31
  )}
32
  >
33
- <Login />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  </div>
35
  </div>
36
  );
 
1
+ /* eslint-disable @typescript-eslint/no-explicit-any */
2
  import classNames from "classnames";
 
3
 
4
+ import { PiGearSixFill } from "react-icons/pi";
5
+ // @ts-expect-error not needed
6
+ import { PROVIDERS } from "./../../../utils/providers";
7
 
8
+ function Settings({
9
+ open,
10
+ onClose,
11
+ provider,
12
+ error,
13
+ onChange,
14
+ }: {
15
+ open: boolean;
16
+ provider: string;
17
+ error?: string;
18
+ onClose: React.Dispatch<React.SetStateAction<boolean>>;
19
+ onChange: (provider: string) => void;
20
+ }) {
21
  return (
22
+ <div className="">
23
  <button
24
+ className="relative overflow-hidden cursor-pointer flex-none flex items-center justify-center rounded-full text-base font-semibold size-8 text-center bg-gray-800 hover:bg-gray-700 text-gray-100 shadow-sm dark:shadow-highlight/20"
25
+ onClick={() => {
26
+ onClose((prev) => !prev);
27
+ }}
28
  >
29
+ <PiGearSixFill />
30
  </button>
31
  <div
32
  className={classNames(
 
35
  "opacity-0 pointer-events-none": !open,
36
  }
37
  )}
38
+ onClick={() => onClose(false)}
39
  ></div>
40
  <div
41
  className={classNames(
42
+ "absolute top-0 -translate-y-[calc(100%+16px)] right-0 z-10 w-96 bg-white border border-gray-200 rounded-lg shadow-lg transition-all duration-75 overflow-hidden",
43
  {
44
  "opacity-0 pointer-events-none": !open,
45
  }
46
  )}
47
  >
48
+ <header className="flex items-center text-sm px-4 py-2 border-b border-gray-200 gap-2 bg-gray-100 font-semibold text-gray-700">
49
+ <span className="text-xs bg-blue-500/10 text-blue-500 rounded-full pl-1.5 pr-2.5 py-0.5 flex items-center justify-start gap-1.5">
50
+ Provider
51
+ </span>
52
+ Customize Settings
53
+ </header>
54
+ <main className="px-4 pt-3 pb-4 space-y-3">
55
+ {error !== "" && (
56
+ <p className="text-red-500 text-sm font-medium mb-2 flex items-center justify-between bg-red-500/10 p-2 rounded-md">
57
+ {error}
58
+ </p>
59
+ )}
60
+ <label className="block">
61
+ <p className="text-gray-500 text-sm font-medium mb-2 flex items-center justify-between">
62
+ Inference Provider
63
+ </p>
64
+ <div className="grid grid-cols-2 gap-1.5">
65
+ {PROVIDERS.map((item: any) => (
66
+ <div
67
+ key={item.id}
68
+ className={classNames(
69
+ "text-gray-600 text-sm font-medium cursor-pointer border p-2 rounded-md flex items-center justify-start gap-2",
70
+ {
71
+ "bg-blue-500/10 border-blue-500/15 text-blue-500":
72
+ item.id === provider,
73
+ "hover:bg-gray-100 border-gray-100": item.id !== provider,
74
+ }
75
+ )}
76
+ onClick={() => {
77
+ onChange(item.id);
78
+ }}
79
+ >
80
+ <img
81
+ src={`/providers/${item.id}.svg`}
82
+ alt={item.name}
83
+ className="size-5"
84
+ />
85
+ {item.name}
86
+ </div>
87
+ ))}
88
+ </div>
89
+ {/* <input
90
+ type="password"
91
+ autoComplete="off"
92
+ className="mr-2 border rounded-md px-3 py-1.5 border-gray-300 w-full text-sm"
93
+ placeholder="hf_******"
94
+ value={tokenStorage[0] as string}
95
+ onChange={(e) => {
96
+ if (e.target.value.length > 0) {
97
+ tokenStorage[1](e.target.value);
98
+ } else {
99
+ tokenStorage[2]();
100
+ }
101
+ }}
102
+ /> */}
103
+ </label>
104
+ </main>
105
  </div>
106
  </div>
107
  );
src/components/tabs/tabs.tsx CHANGED
@@ -14,7 +14,7 @@ function Tabs({ children }: { children?: React.ReactNode }) {
14
  </div>
15
  <div className="flex items-center justify-end gap-3">
16
  <a
17
- href="https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?inference_provider=fireworks-ai"
18
  target="_blank"
19
  className="text-[12px] text-gray-300 hover:brightness-120 flex items-center gap-1 font-code"
20
  >
 
14
  </div>
15
  <div className="flex items-center justify-end gap-3">
16
  <a
17
+ href="https://huggingface.co/deepseek-ai/DeepSeek-V3-0324"
18
  target="_blank"
19
  className="text-[12px] text-gray-300 hover:brightness-120 flex items-center gap-1 font-code"
20
  >
tsconfig.app.json CHANGED
@@ -22,5 +22,5 @@
22
  "noFallthroughCasesInSwitch": true,
23
  "noUncheckedSideEffectImports": true
24
  },
25
- "include": ["src", "middleware"]
26
  }
 
22
  "noFallthroughCasesInSwitch": true,
23
  "noUncheckedSideEffectImports": true
24
  },
25
+ "include": ["src", "middleware", "utils/consts.ts", "utils/types.ts"]
26
  }
{src/utils → utils}/consts.ts RENAMED
File without changes
utils/providers.js ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export const PROVIDERS = [
2
+ {
3
+ name: "Fireworks AI",
4
+ id: "fireworks-ai",
5
+ max_tokens: 131_000,
6
+ },
7
+ {
8
+ name: "Nebius AI Studio",
9
+ id: "nebius",
10
+ max_tokens: 131_000,
11
+ },
12
+ {
13
+ name: "SambaNova",
14
+ id: "sambanova",
15
+ max_tokens: 8_000,
16
+ },
17
+ {
18
+ name: "Hyperbolic",
19
+ id: "hyperbolic",
20
+ max_tokens: 131_000,
21
+ },
22
+ ];
{src/utils → utils}/types.ts RENAMED
File without changes