miracFence committed on
Commit a295415 · verified · 1 Parent(s): 1053707

Update app.py

Files changed (1)
  1. app.py +26 -28
app.py CHANGED
@@ -26,33 +26,33 @@ model = AutoModelForCausalLM.from_pretrained(model_name,
                                              device_map="auto")
 model.eval()
 
-def format_history(msg: str, history: list[list[str, str]], system_prompt: str):
-    chat_history = system_prompt
-    for query, response in history:
-        chat_history += f"\nUser: {query}\nAssistant: {response}"
-    chat_history += f"\nUser: {msg}\nAssistant:"
-    return chat_history
-
 @spaces.GPU(duration=90)
-def generate(msg: str,
-             history: list[list[str, str]],
-             system_prompt: str,
-             max_new_tokens: int = 1024,
-             temperature: float = 0.6,
-             top_p: float = 0.9,
-             top_k: int = 50,
-             repetition_penalty: float = 1.2,) -> Iterator[str]:
-    chat_history = format_history(msg, history, system_prompt)
-
-    # Tokenize the input prompt
-    input_ids = tokenizer(chat_history, return_tensors="pt").to("cuda")
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-
-    # Generate a response using the model
-    # outputs = model.generate(inputs["input_ids"], max_length=500, pad_token_id=tokenizer.eos_token_id)
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend(
+            [
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ]
+        )
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
 
-    # Decode the response back to a string
-    # response = tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
@@ -66,9 +66,7 @@ def generate(msg: str,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
-
-    # Yield the generated response
-    #yield response
+
     outputs = []
    for text in streamer:
        outputs.append(text)
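
For reference, below is a minimal, self-contained sketch of the streaming path this commit moves to, assembled from the added lines above. It is a hedged reconstruction, not the exact file: the checkpoint name, the MAX_INPUT_TOKEN_LENGTH value, the sampling arguments inside generate_kwargs between the lines shown, and the final yield of the accumulated text are not visible in the diff and are filled in with the common Transformers streaming pattern; gr.Warning is left out so the sketch runs without Gradio.

# Hypothetical reconstruction of the updated generate() flow; anything marked
# "assumed" below is not shown in the diff.
from collections.abc import Iterator
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_INPUT_TOKEN_LENGTH = 4096   # assumed value; the constant is defined elsewhere in app.py
model_name = "org/model-name"   # placeholder; the real checkpoint is set earlier in app.py

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
model.eval()


def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Rebuild the conversation as role/content messages and let the tokenizer's
    # chat template produce model-specific prompt tokens.
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the app also raises a gr.Warning here.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    # Generation runs in a worker thread while tokens stream back through the iterator.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=max_new_tokens,   # assumed sampling arguments; the diff omits this middle section
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)   # assumed: emit the growing reply for the chat UI

In a ZeroGPU Space of this kind, the generator would typically sit under the @spaces.GPU(duration=90) decorator shown in the diff and be passed as the fn callback of a Gradio chat interface, which re-renders the partial reply on every yield.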