File size: 3,438 Bytes
2a0f243
cf47d83
69e3a41
3d477e1
ea3c34e
69e3a41
2a0f243
85ec4d4
 
 
2a0f243
763be08
ea3c34e
 
a77f2b3
 
338b938
70ed6f0
338b938
b99eeda
a3bc7ec
763be08
 
 
 
338b938
ea3c34e
 
 
 
338b938
ea3c34e
 
 
a3bc7ec
763be08
 
 
338b938
904cf6c
763be08
 
 
 
 
338b938
 
 
 
763be08
 
 
5d3dc3e
ea3c34e
b99eeda
ea3c34e
 
 
 
 
763be08
 
 
 
 
 
 
105c4c8
338b938
763be08
 
 
 
 
338b938
763be08
 
 
 
2b5a681
ea3c34e
 
 
b99eeda
ea3c34e
 
763be08
 
2b5a681
ea3c34e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import streamlit as st
from huggingface_hub import login
import pandas as pd
from threading import Thread

# Token Secret de Hugging Face
huggingface_token = st.secrets["HUGGINGFACEHUB_API_TOKEN"]
login(huggingface_token)

# Cargar el tokenizador y el modelo
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer.pad_token = tokenizer.eos_token

MAX_INPUT_TOKEN_LENGTH = 4096

def generate_response(input_text, temperature=0.5, max_new_tokens=20):
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        st.warning(f"Se recort贸 la entrada porque excedi贸 el l铆mite de {MAX_INPUT_TOKEN_LENGTH} tokens.")

    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        num_beams=3,  # Usar beam search
        temperature=temperature,
        eos_token_id=[tokenizer.eos_token_id]
    )

    try:
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        t.join()  # Asegura que la generaci贸n haya terminado

        outputs = []
        for text in streamer:
            outputs.append(text)
        if not outputs:
            raise ValueError("No se gener贸 ninguna respuesta.")
        
        # Post-procesamiento m谩s restrictivo
        response = "".join(outputs).strip().split("\n")[0]
        return response
    except Exception as e:
        st.error(f"Error durante la generaci贸n: {e}")
        return "Error en la generaci贸n de texto."

def main():
    st.title("Chat con Meta Llama 3.2 1B")
    
    uploaded_file = st.file_uploader("Por favor, sube un archivo CSV para iniciar:", type=["csv"])
    
    if uploaded_file is not None:
        df = pd.read_csv(uploaded_file)
        
        if 'job_title' in df.columns:
            job_titles = df['job_title'].tolist()
            query = "aspiring human resources specialist"
            
            st.write("Archivo CSV cargado exitosamente:")
            st.write(df.head())

            initial_prompt = f"The list of job titles is: {job_titles}. Extract only the first job title from the list and return it as the answer."
            st.write(f"Query: {query}")
            st.write(f"Prompt inicial: {initial_prompt}")

            if st.button("Generar respuesta"):
                with st.spinner("Generando respuesta..."):
                    response = generate_response(initial_prompt, temperature=0.5)
                    if response:
                        st.write(f"Respuesta del modelo: {response}")
                    else:
                        st.warning("No se pudo generar una respuesta.")

                st.success("La conversaci贸n ha terminado.")
                
                if st.button("Iniciar nueva conversaci贸n"):
                    st.experimental_rerun()
                elif st.button("Terminar"):
                    st.stop()
        else:
            st.error("La columna 'job_title' no se encuentra en el archivo CSV.")

if __name__ == "__main__":
    main()