bobpopboom committed
Commit c215361 · verified · 1 Parent(s): 9c7a437
Files changed (1)
  1. app.py +35 -48
app.py CHANGED
@@ -1,65 +1,52 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch

-# Determine device
-device = "cuda" if torch.cuda.is_available() else "cpu"

-model_id = "mradermacher/TinyLlama-Friendly-Psychotherapist-GGUF"

 try:
-    # Load model with appropriate settings
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True,
-        max_memory={0: "15GiB"} if torch.cuda.is_available() else None,
-        offload_folder="offload",
-    ).eval()
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.model_max_length = 4096  # Set to model's actual context length

 except Exception as e:
     print(f"Error loading model: {e}")
     exit()

 def generate_text_streaming(prompt, max_new_tokens=128):
-    inputs = tokenizer(
-        prompt,
-        return_tensors="pt",
-        truncation=True,
-        max_length=4096  # Match model's context length
-    ).to(model.device)
-
     generated_tokens = []
-    with torch.no_grad():
-        for _ in range(max_new_tokens):
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=1,
-                do_sample=False,
-                eos_token_id=tokenizer.eos_token_id,
-                return_dict_in_generate=True
-            )
-
-            new_token = outputs.sequences[0, -1]
-            generated_tokens.append(new_token)
-
-            # Update inputs for next iteration
-            inputs = {
-                "input_ids": torch.cat([inputs["input_ids"], new_token.unsqueeze(0).unsqueeze(0)], dim=-1),
-                "attention_mask": torch.cat([inputs["attention_mask"], torch.ones(1, 1, device=model.device)], dim=-1)
-            }
-
-            # Decode the accumulated tokens
-            current_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-            yield current_text  # Yield the full text so far
-
-            if new_token == tokenizer.eos_token_id:
-                break

 def respond(message, history, system_message, max_tokens):
     # Build prompt with full history
 
 import gradio as gr
+from transformers import AutoTokenizer
+import ctranslate2
 import torch

+# Determine device (ctranslate2 handles device placement internally)
+device = "cuda" if torch.cuda.is_available() else "cpu"  # Still useful for other ops

+# NOTE: CTranslate2 loads its own converted model format (not GGUF), so this
+# path should point to a CTranslate2-converted model directory.
+model_path = "mradermacher/TinyLlama-Friendly-Psychotherapist-GGUF"

 try:
+    # 1. Load the tokenizer (same as before)
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.model_max_length = 4096

+    # 2. Load the CTranslate2 model (decoder-only models use Generator, not Translator)
+    ct_model = ctranslate2.Generator(model_path, device=device)
 except Exception as e:
     print(f"Error loading model: {e}")
     exit()

 def generate_text_streaming(prompt, max_new_tokens=128):
+    # CTranslate2 consumes token strings rather than tensors of ids
+    prompt_ids = tokenizer.encode(prompt, truncation=True, max_length=4096)
+    tokens = tokenizer.convert_ids_to_tokens(prompt_ids)

     generated_tokens = []

+    for _ in range(max_new_tokens):
+        # Greedy decoding, one new token per step (re-runs the prompt each step: simple, not fast)
+        results = ct_model.generate_batch(
+            [tokens],
+            max_length=1,
+            beam_size=1,
+            include_prompt_in_result=False,
+        )

+        new_token_id = results[0].sequences_ids[0][-1]  # Extract the generated token ID

+        if new_token_id == tokenizer.eos_token_id:
+            break

+        generated_tokens.append(new_token_id)

+        current_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        yield current_text

+        # Feed the new token back in for the next step
+        tokens.append(tokenizer.convert_ids_to_tokens(new_token_id))

 def respond(message, history, system_message, max_tokens):
     # Build prompt with full history
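
For reference: CTranslate2 cannot read GGUF weights, so the Generator used above expects a directory produced by the CTranslate2 converter. A minimal conversion sketch (run once, outside app.py) is below, assuming the original Transformers checkpoint of the fine-tune is available; the source id and output directory are placeholders, not names taken from this commit.

from ctranslate2.converters import TransformersConverter

source_model = "path-or-hub-id-of-the-original-transformers-checkpoint"  # placeholder, not the GGUF repo
converter = TransformersConverter(source_model)
converter.convert("tinyllama-ct2", quantization="int8")  # int8 keeps memory low on CPU Spaces

# app.py would then point model_path at the converted directory:
# model_path = "tinyllama-ct2"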