michailroussos committed on
Commit ebd9e26 · 1 Parent(s): e8ace7a

more changes

Files changed (1):
  app.py  +10 -16
app.py CHANGED
```diff
@@ -36,30 +36,24 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
         return_tensors="pt",
     ).to("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Use TextStreamer to process and yield outputs incrementally
-    class GradioStreamer(TextStreamer):
-        def __init__(self, tokenizer, *args, **kwargs):
-            super().__init__(tokenizer, *args, **kwargs)
-            self.generated_text = ""
-
-        def on_token(self, token_id):
-            token = self.tokenizer.decode(token_id, skip_special_tokens=True)
-            self.generated_text += token
-            yield self.generated_text
-
-    # Initialize Gradio-compatible streamer
-    streamer = GradioStreamer(tokenizer, skip_prompt=True)
-
-    # Generate response with streaming
-    _ = model.generate(
+    attention_mask = inputs.ne(tokenizer.pad_token_id).long()  # Explicitly set attention mask
+
+    # Generate response tokens
+    generated_tokens = model.generate(
         input_ids=inputs,
+        attention_mask=attention_mask,
         max_new_tokens=max_tokens,
         use_cache=True,
         temperature=temperature,
         top_p=top_p,
-        streamer=streamer,
     )
+
+    # Decode generated tokens
+    response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+
+    # Yield response in the required Gradio format
+    yield [{"role": "assistant", "content": response}]
 
 
 # Define the Gradio interface
```
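The new masking line assumes the tokenizer actually has a pad token: many causal-LM tokenizers ship with pad_token_id unset, in which case inputs.ne(None) raises a TypeError. A minimal sketch of a guard, assuming the same inputs tensor and tokenizer objects used in respond (the helper name build_attention_mask is hypothetical):

```python
import torch
from transformers import PreTrainedTokenizer

def build_attention_mask(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer) -> torch.Tensor:
    """Hypothetical helper mirroring the commit's masking line, with a
    guard for tokenizers that have no pad token (common for causal LMs)."""
    if tokenizer.pad_token_id is None:
        # Reuse EOS as padding so .ne() has a concrete id to compare against.
        tokenizer.pad_token = tokenizer.eos_token
    # 1 for real tokens, 0 for padding, matching the line added in this diff.
    return inputs.ne(tokenizer.pad_token_id).long()
```

For a single unpadded prompt the mask is all ones either way; the guard only matters once batching or padding enters the picture.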
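The deleted GradioStreamer never worked as intended: TextStreamer has no on_token hook (its callback is on_finalized_text), and a yield inside a callback cannot feed Gradio's generator, which is presumably why this commit switches to a blocking generate plus a single yield. If incremental output is wanted later, transformers' TextIteratorStreamer is the supported route; below is a minimal, untested sketch assuming the same model, tokenizer, inputs, and sampling parameters as respond:

```python
from threading import Thread

from transformers import TextIteratorStreamer

def respond_streaming(inputs, tokenizer, model, max_tokens, temperature, top_p):
    # Skip the prompt and special tokens so only newly generated text is emitted.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs,
        attention_mask=inputs.ne(tokenizer.pad_token_id).long(),
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        use_cache=True,
        streamer=streamer,
    )
    # generate() blocks, so run it in a background thread and consume the
    # streamer on the main thread as tokens are finalized.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        # Same message format the commit yields once at the end.
        yield [{"role": "assistant", "content": partial}]
    thread.join()
```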