Mahavaury2 committed
Commit eb04a36 · verified · 1 Parent(s): 30169f7

Update app.py

Files changed (1)
  1. app.py +25 -43
app.py CHANGED
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 import os
 from collections.abc import Iterator
 from threading import Thread
@@ -8,32 +6,24 @@ import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-#
-# 1) Custom Pastel Gradient CSS, and force text to black
-#
 CUSTOM_CSS = """
 .gradio-container {
     background: linear-gradient(to right, #FFDEE9, #B5FFFC);
-    color: black; /* ensure text appears in black */
+    color: black;
 }
 """
 
-#
-# 2) Description: "Bonjour Dans le chat du consentement" in black
-#    Also add a CPU notice in black if no GPU is found.
-#
 DESCRIPTION = """# Bonjour Dans le chat du consentement
 Mistral-7B Instruct Demo
 """
 
-if not torch.cuda.is_available():
-    DESCRIPTION += "Running on CPU - This is likely too large to run effectively.\n"
+MAX_INPUT_TOKEN_LENGTH = 4096  # just a default
 
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+# Define model/tokenizer at the top so they're visible in all scopes
+tokenizer = None
+model = None
 
-#
-# 3) Load Mistral-7B Instruct (requires gating, GPU recommended)
-#
+# Try to load the model only if GPU is available
 if torch.cuda.is_available():
     model_id = "mistralai/Mistral-7B-Instruct-v0.3"
     tokenizer = AutoTokenizer.from_pretrained(
@@ -46,26 +36,25 @@ if torch.cuda.is_available():
         device_map="auto",
         trust_remote_code=True
     )
+else:
+    # Show a warning in the description
+    DESCRIPTION += "\n**Running on CPU** — This model is too large for CPU inference!"
+
+def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
+    # If there's no GPU (thus no tokenizer/model), return an error
+    if tokenizer is None or model is None:
+        yield "Error: No GPU available. Unable to load Mistral-7B-Instruct."
+        return
 
-def generate(
-    message: str,
-    chat_history: list[dict],
-) -> Iterator[str]:
-    """
-    Minimal chat generation function: no sliders, no extra params.
-    """
     conversation = [*chat_history, {"role": "user", "content": message}]
 
-    # Convert conversation to tokens
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    # If it exceeds max token length, trim
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
     input_ids = input_ids.to(model.device)
 
-    # Use a streamer to yield tokens as they are generated
     streamer = TextIteratorStreamer(
         tokenizer,
         timeout=20.0,
@@ -73,18 +62,16 @@ def generate(
         skip_special_tokens=True
     )
 
-    # Basic generation settings (feel free to adjust if you want)
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=512, # Adjust if you want more or fewer tokens
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-        repetition_penalty=1.1,
-    )
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": 512,
+        "do_sample": True,
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "repetition_penalty": 1.1,
+    }
 
-    # Run generation in a background thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
@@ -93,16 +80,11 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
-#
-# 4) Build the Chat Interface
-#    - No additional sliders
-#    - No pre-filled example questions
-#
 demo = gr.ChatInterface(
     fn=generate,
     description=DESCRIPTION,
     css=CUSTOM_CSS,
-    examples=None, # remove example prompts
+    examples=None,
     type="messages"
 )
 
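Both sides of this diff keep the tokenizer and model load inside the `if torch.cuda.is_available():` branch, but the hunks cut off in the middle of the `from_pretrained` calls. A minimal sketch of that branch, assuming a `torch_dtype` choice the diff does not show:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = None
model = None

if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    # Mistral-7B-Instruct-v0.3 is a gated repo, so the Space needs an accepted
    # license / access token for the download to succeed.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # assumption: the dtype is not visible in the hunks
        device_map="auto",
        trust_remote_code=True,
    )
```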
 
 
 
 
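The generation hunks stop right after `t.start()` and resume inside the loop, so the code that actually consumes the streamer is never shown. As a self-contained sketch of the pattern (the `stream_reply` helper and its signature are illustrative, not part of app.py):

```python
from collections.abc import Iterator
from threading import Thread

from transformers import TextIteratorStreamer


def stream_reply(model, tokenizer, input_ids, max_new_tokens: int = 512) -> Iterator[str]:
    """Run model.generate in a background thread and yield the growing reply."""
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    }

    # model.generate blocks, so it runs in a thread while the streamer is consumed here.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # This is the part the hunks elide: accumulate decoded chunks and yield the
    # running transcript so the UI can re-render the partial reply.
    outputs: list[str] = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
```

`TextIteratorStreamer` is itself an iterator over decoded text chunks, which is why the Gradio callback can simply yield the growing string.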
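With `type="messages"`, `gr.ChatInterface` passes `fn` the latest user message plus the prior history as a list of `{"role": ..., "content": ...}` dicts, which matches the new `generate(message, chat_history)` signature. The commit does not show how the demo is started; a typical tail for a Space would be (an assumption, not part of this diff):

```python
if __name__ == "__main__":
    # Queue requests so concurrent chats share the single model instance cleanly.
    demo.queue(max_size=20).launch()
```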