ccm committed
Commit 30d4d88 · verified · 1 Parent(s): 2ce9b97

Update main.py

Files changed (1)
  1. main.py +30 -15
main.py CHANGED
@@ -7,6 +7,7 @@ import pandas # to work with pandas
 import json # to work with JSON
 import datasets # to load the dataset
 import spaces # for GPU
+import threading
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
@@ -60,12 +61,12 @@ def search(query: str, k: int) -> tuple[str]:
 
 
 # Create an LLM pipeline that we can send queries to
-pipe = transformers.pipeline(
-    "text-generation",
-    model="Qwen/Qwen2-0.5B-Instruct",
-    trust_remote_code=True,
-    max_new_tokens = 512,
-    device="cuda:0",
+tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
+    "Qwen/Qwen2-0.5B-Instruct",
+    torch_dtype="auto",
+    device_map="auto"
 )
 
 def preprocess(message: str) -> tuple[str]:
@@ -77,7 +78,6 @@ def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """Applies a postprocessing step to the LLM's response before the user receives it"""
     return response + bypass_from_preprocessing
 
-@spaces.GPU
 def predict(message: str, history: list[str]) -> str:
     """This function is responsible for crafting a response"""
 
@@ -93,14 +93,29 @@ def predict(message: str, history: list[str]) -> str:
         for idx, msg in enumerate(history)
     ] + [{"role": "user", "content": message}]
 
-    # Create a response
-    response = pipe(history_transformer_format)
-    response_message = response[0]["generated_text"][-1]["content"]
-
-    # Apply postprocessing
-    response_message = postprocess(response_message, bypass)
-
-    return response_message
+    # Stream a response from the model
+    text = tokenizer.apply_chat_template(
+        history_transformer_format,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    model_inputs = tokenizer([text], return_tensors="pt")
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=512
+    )
+    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
+    t.start()
+
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
+
+    yield bypass
 
 
 # Create and run the gradio interface
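
The change replaces the blocking transformers.pipeline call with token-by-token streaming: generate() runs on a background thread while TextIteratorStreamer yields decoded text as it arrives, so predict becomes a generator the UI can render incrementally. A minimal standalone sketch of the same pattern, assuming the same Qwen/Qwen2-0.5B-Instruct checkpoint; the demo_stream helper and its prompt are illustrative, not part of main.py:

import threading

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)

def demo_stream(prompt: str):
    """Yield the growing partial response, one decoded chunk at a time."""
    # A fresh streamer per call avoids interleaving tokens across requests;
    # main.py shares a single module-level streamer instead.
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # generate() blocks until decoding finishes, so it runs on a worker
    # thread while this thread drains the streamer as tokens are produced.
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(inputs, streamer=streamer, max_new_tokens=512),
    )
    thread.start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
    thread.join()

for update in demo_stream("What is a tokenizer?"):
    print(update)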
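
The trailing context line suggests main.py ends by wiring predict into a Gradio chat UI, which sits outside this diff. A sketch of how a streaming generator like predict typically plugs in, assuming gr.ChatInterface (hypothetical wiring, not confirmed by the diff):

import gradio as gr

# Gradio treats a generator handler as a streaming response: each yielded
# string re-renders the assistant's chat bubble with the latest partial text.
demo = gr.ChatInterface(fn=predict)
demo.launch()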