Update app.py
app.py (CHANGED)

--- a/app.py
+++ b/app.py
@@ -211,7 +211,7 @@ EXAMPLES = [
 
 # Define the simple_chat function
 @spaces.GPU()
-def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
+async def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096, top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
     # Load the pretrained model
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
@@ -219,7 +219,7 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         low_cpu_mem_usage=True,
         trust_remote_code=True
     )
-
+
     conversation = []
 
     if "file" in message and message["file"]:
@@ -236,12 +236,10 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
     # Prepare the model input
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                               return_tensors="pt", return_dict=True).to(model.device)
-
-
+
     # Configure generation parameters
     generate_kwargs = dict(
         max_length=max_length,
-        streamer=streamer,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
@@ -249,19 +247,13 @@ def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
         repetition_penalty=penalty,
         eos_token_id=[151329, 151336, 151338],
     )
-    gen_kwargs = {**input_ids, **generate_kwargs}
 
-    # Generate the response
-
-
-
-
-
-    for new_text in streamer:
-        buffer += new_text
-        yield buffer.encode('utf-8')
-
-    #return StreamingResponse(generate(), media_type="text/plain")
+    # Generate the response
+    with torch.no_grad():
+        generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
+        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+    # Return the full response
+    # Previously the function yielded partial output from a streamer; now it waits for
+    # generation to finish and returns the whole reply at once.
     return PlainTextResponse(generated_text)
 
 @app.post("/chat/")
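For reference, below is a minimal sketch of how simple_chat reads once this change is applied, assembled from the hunks above. Anything the diff does not show is an assumption and is marked as such in comments: the imports, the definitions of MODEL_ID and tokenizer elsewhere in app.py, the from_pretrained arguments on the elided line between the first two hunks, the temperature argument on the elided line before repetition_penalty, and the file-handling block that builds conversation. This is not the full app.py.

import spaces                                   # Hugging Face Spaces GPU decorator (assumed import)
import torch
from transformers import AutoModelForCausalLM   # tokenizer and MODEL_ID are defined elsewhere in app.py
from fastapi.responses import PlainTextResponse

@spaces.GPU()
async def simple_chat(message: dict, temperature: float = 0.8, max_length: int = 4096,
                      top_p: float = 1, top_k: int = 10, penalty: float = 1.0):
    # Load the pretrained model (MODEL_ID comes from earlier in app.py;
    # the dtype/device arguments on the line elided by the diff are not reproduced here)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        low_cpu_mem_usage=True,
        trust_remote_code=True
    )

    conversation = []
    # ... the elided lines of the diff handle message["file"] / message text
    #     and append the turns to `conversation` (not shown here) ...

    # Prepare the model input
    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
                                              return_tensors="pt", return_dict=True).to(model.device)

    # Configure generation parameters; with the streamer removed, nothing is streamed any more
    generate_kwargs = dict(
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        # the elided line before repetition_penalty presumably passes temperature (not shown in the diff)
        repetition_penalty=penalty,
        eos_token_id=[151329, 151336, 151338],
    )

    # Generate the full response in one pass and return it as plain text
    with torch.no_grad():
        generated_ids = model.generate(input_ids['input_ids'], **generate_kwargs)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return PlainTextResponse(generated_text)

The practical effect of the change: the old code collected chunks from a streamer and yielded them incrementally, while the new code waits for model.generate to finish and returns the complete reply in a single PlainTextResponse, which also makes the async def handler straightforward to call from the FastAPI route below.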