JaphetHernandez committed on
Commit 338b938 (verified)
Parent(s): 6eb5316

Update app.py

Files changed (1)
app.py +11 -10
app.py CHANGED
@@ -14,23 +14,21 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
 tokenizer.pad_token = tokenizer.eos_token
 
-MAX_INPUT_TOKEN_LENGTH = 10000
+MAX_INPUT_TOKEN_LENGTH = 4096
 
-def generate_response(input_text, temperature=0.5, max_new_tokens=100):
+def generate_response(input_text, temperature=0.5, max_new_tokens=20):
     input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)
 
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         st.warning(f"Se recortó la entrada porque excedió el límite de {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_k=50,
-        top_p=0.9,
+        num_beams=3,  # Usar beam search
         temperature=temperature,
         eos_token_id=[tokenizer.eos_token_id]
     )
@@ -38,14 +36,17 @@ def generate_response(input_text, temperature=0.5, max_new_tokens=100):
     try:
         t = Thread(target=model.generate, kwargs=generate_kwargs)
         t.start()
-        t.join()  # Esperar a que el hilo termine
+        t.join()  # Asegura que la generación haya terminado
 
         outputs = []
         for text in streamer:
             outputs.append(text)
         if not outputs:
             raise ValueError("No se generó ninguna respuesta.")
-        return "".join(outputs)
+
+        # Post-procesamiento más restrictivo
+        response = "".join(outputs).strip().split("\n")[0]
+        return response
     except Exception as e:
         st.error(f"Error durante la generación: {e}")
         return "Error en la generación de texto."
@@ -65,13 +66,13 @@ def main():
     st.write("Archivo CSV cargado exitosamente:")
     st.write(df.head())
 
-    initial_prompt = f"I have a list of job titles: {job_titles}. Please extract and return only the first job title from this list without repeating."
+    initial_prompt = f"The list of job titles is: {job_titles}. Extract only the first job title from the list and return it as the answer."
     st.write(f"Query: {query}")
     st.write(f"Prompt inicial: {initial_prompt}")
 
     if st.button("Generar respuesta"):
         with st.spinner("Generando respuesta..."):
-            response = generate_response(initial_prompt, temperature=0.2)
+            response = generate_response(initial_prompt, temperature=0.5)
             if response:
                 st.write(f"Respuesta del modelo: {response}")
             else: