miracFence committed
Commit fd4a241 · verified · 1 parent: 2ec390b

Update app.py

Files changed (1)
  1. app.py +22 -51
app.py CHANGED
@@ -1,12 +1,8 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
 import spaces
 
-import os
-from threading import Thread
-from typing import Iterator
-
 # Define quantization configuration
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,  # Specify 4-bit quantization
@@ -18,64 +14,39 @@ quantization_config = BitsAndBytesConfig(
 # Load the tokenizer and quantized model from Hugging Face
 model_name = "llSourcell/medllama2_7b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Load model with quantization
 model = AutoModelForCausalLM.from_pretrained(model_name,
                                              quantization_config=quantization_config,
                                              device_map="auto")
 model.eval()
-max_token_length = 4096
-
-@spaces.GPU(duration=15)
-def generate(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend(
-            [
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
-        )
-    conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    if input_ids.shape[1] > max_token_length:
-        input_ids = input_ids[:, -max_token_length:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {max_token_length} tokens.")
-    input_ids = input_ids.to(model.device)
+def format_history(msg: str, history: list[list[str, str]], system_prompt: str):
+    chat_history = system_prompt
+    for query, response in history:
+        chat_history += f"\nUser: {query}\nAssistant: {response}"
+    chat_history += f"\nUser: {msg}\nAssistant:"
+    return chat_history
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
+@spaces.GPU(duration=30)
+def generate_response(msg: str, history: list[list[str, str]], system_prompt: str):
+    chat_history = format_history(msg, history, system_prompt)
+
+    # Tokenize the input prompt
+    inputs = tokenizer(chat_history, return_tensors="pt").to("cuda")
+
+    # Generate a response using the model
+    outputs = model.generate(inputs["input_ids"], max_length=1024, pad_token_id=tokenizer.eos_token_id)
 
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # Decode the response back to a string
+    response = tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True)
+
+    # Yield the generated response
+    yield response
 
 # Define the Gradio ChatInterface
 chatbot = gr.ChatInterface(
-    fn=generate,
+    generate_response,
     chatbot=gr.Chatbot(
         height="64vh"
     ),
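
For reference, a minimal, self-contained sketch of what the new format_history helper produces. The function body is copied from the commit; the sample system prompt, history, and message below are invented for illustration and are not part of the repository.

# Standalone sketch of the prompt-formatting step introduced in this commit.
# The sample system prompt, history, and message are illustrative only.

def format_history(msg: str, history: list[list[str, str]], system_prompt: str):
    chat_history = system_prompt
    for query, response in history:
        chat_history += f"\nUser: {query}\nAssistant: {response}"
    chat_history += f"\nUser: {msg}\nAssistant:"
    return chat_history

if __name__ == "__main__":
    system_prompt = "You are a helpful medical assistant."  # hypothetical prompt
    history = [["What is hypertension?", "Hypertension is high blood pressure."]]
    prompt = format_history("How is it treated?", history, system_prompt)
    print(prompt)
    # You are a helpful medical assistant.
    # User: What is hypertension?
    # Assistant: Hypertension is high blood pressure.
    # User: How is it treated?
    # Assistant:

The resulting plain-text prompt is what generate_response tokenizes and passes to model.generate; the decode step then slices off the prompt tokens (outputs[:, inputs["input_ids"].shape[-1]:]) so only the newly generated text is returned to the chat UI.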