Spaces: Running on Zero
Update main.py
main.py CHANGED
@@ -7,6 +7,7 @@ import pandas # to work with pandas
 import json # to work with JSON
 import datasets # to load the dataset
 import spaces # for GPU
+import threading
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
@@ -60,12 +61,12 @@ def search(query: str, k: int) -> tuple[str]:
 
 
 # Create an LLM pipeline that we can send queries to
-
-
-
-
-
-
+tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
+    "Qwen/Qwen2-0.5B-Instruct",
+    torch_dtype="auto",
+    device_map="auto"
 )
 
 def preprocess(message: str) -> tuple[str]:
@@ -77,7 +78,6 @@ def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """Applies a postprocessing step to the LLM's response before the user receives it"""
     return response + bypass_from_preprocessing
 
-@spaces.GPU
 def predict(message: str, history: list[str]) -> str:
     """This function is responsible for crafting a response"""
 
@@ -93,14 +93,29 @@ def predict(message: str, history: list[str]) -> str:
         for idx, msg in enumerate(history)
     ] + [{"role": "user", "content": message}]
 
-    #
-
-
-
-
-
-
-
+    # Stream a response from pipe
+    text = tokenizer.apply_chat_template(
+        history_transformer_format,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+    model_inputs = tokenizer([text], return_tensors="pt")
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=512
+    )
+    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
+    t.start()
+
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
+
+    yield bypass
 
 
 # Create and run the gradio interface
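For context, the added code follows the usual Transformers token-streaming pattern: generate() runs in a background thread while a TextIteratorStreamer yields decoded tokens, and the chat function re-emits a growing string. Below is a minimal, self-contained sketch of that pattern; the stream_reply helper name and the demo prompt are illustrative only, and the Space's retrieval, preprocess, and postprocess steps are omitted.

# Minimal sketch of threaded streaming generation (model name taken from the diff)
import threading
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = transformers.AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def stream_reply(messages: list[dict]):
    """Yield a progressively longer reply for a chat-formatted message list."""
    # Render the chat history into the model's prompt format
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # generate() blocks, so run it in a background thread and read tokens from the streamer
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(inputs, streamer=streamer, max_new_tokens=512),
    )
    thread.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial
    thread.join()

if __name__ == "__main__":
    for chunk in stream_reply([{"role": "user", "content": "Say hello in one sentence."}]):
        print(chunk)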
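The trailing context line ("# Create and run the gradio interface") is the only hint of how the app is launched, so the wiring below is an assumption rather than the Space's actual code. Because predict now yields partial strings instead of returning one, Gradio treats it as a streaming handler and refreshes the reply as tokens arrive; a plausible setup, assuming the predict generator defined above in main.py:

import gradio as gr

# Hypothetical wiring; the commit does not show the actual interface code
demo = gr.ChatInterface(fn=predict)  # a generator fn streams partial messages to the chat window
demo.launch()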