# LexAIcon / app.py
# Author: manuelcozar55 · commit a876827 ("Update app.py", verified) · 8.46 kB
import streamlit as st
from huggingface_hub import snapshot_download
from pathlib import Path
from mistral_inference.model import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from PyPDF2 import PdfReader
from docx import Document
import csv
import json
import os
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
# Download and load the Mistral-7B-Instruct-v0.3 chat model.
mistral_models_path = Path.home().joinpath("mistral_models", "7B-Instruct-v0.3")


@st.cache_resource
def load_mistral_model():
    """Download (if missing) and load the Mistral-7B-Instruct-v0.3 tokenizer and weights.

    Streamlit re-executes the whole script on every widget interaction. Caching
    with ``st.cache_resource`` means the weights are loaded once and the same
    objects are reused on later reruns instead of being reloaded every time.

    Returns:
        tuple: ``(tokenizer, model)`` - the ``MistralTokenizer`` read from
        ``tokenizer.model.v3`` and the ``Transformer`` read from
        ``mistral_models_path``.
    """
    mistral_models_path.mkdir(parents=True, exist_ok=True)
    # Download only the listed files instead of the whole repository.
    snapshot_download(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
        local_dir=mistral_models_path,
    )
    mistral_tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tokenizer.model.v3")
    mistral_model = Transformer.from_folder(mistral_models_path)
    return mistral_tokenizer, mistral_model


# Module-level names used by the generation helpers below.
tokenizer, model = load_mistral_model()
# Document-classification model configuration.
@st.cache_resource
def load_classification_model():
    """Load the Spanish legal Longformer used to route uploaded documents.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of loading them again.

    Returns:
        tuple: ``(model, tokenizer)`` - note that the model comes first.
    """
    checkpoint = "mrm8488/legal-longformer-base-8192-spanish"
    # NOTE(review): the checkpoint name suggests a base (pre-trained) model. If it
    # ships no sequence-classification head, transformers initialises one randomly
    # and the predicted labels are not meaningful - confirm the intended checkpoint.
    cls_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    cls_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return cls_model, cls_tokenizer


classification_model, classification_tokenizer = load_classification_model()
# Maps classifier output indices to category names. main() passes the predicted
# name to load_json_documents, which reads "./<name>.json".
id2label = {0: "multas", 1: "politicas_de_privacidad", 2: "contratos", 3: "denuncias", 4: "otros"}


def classify_text(text, max_length=4096):
    """Predict the legal category of *text*.

    Args:
        text: Document text to classify; it is truncated to *max_length* tokens.
        max_length: Token budget passed to the tokenizer (default 4096, as before).

    Returns:
        str: One of the values of ``id2label``. Falls back to ``"otros"`` if the
        model emits a class index that ``id2label`` does not map.
    """
    # No padding="max_length": padding a single input to 4096 tokens only adds
    # masked positions (wasted compute); Longformer pads to its attention window itself.
    inputs = classification_tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    predicted_class_id = logits.argmax(dim=-1).item()
    return id2label.get(predicted_class_id, "otros")
def load_json_documents(category):
    """Return the knowledge-base entries for *category* as flat strings.

    Reads ``./<category>.json`` relative to the current working directory. The
    file must contain a top-level ``"questions_and_answers"`` list whose items
    have ``"question"`` and ``"answer"`` keys. Each item becomes the string
    ``"<question> <answer>"``.
    """
    kb_path = f"./{category}.json"
    with open(kb_path, "r", encoding="utf-8") as kb_file:
        qa_pairs = json.load(kb_file)["questions_and_answers"]
    return [" ".join((pair["question"], pair["answer"])) for pair in qa_pairs]
def create_vector_store(docs):
    """Build an in-memory FAISS index over *docs*.

    Args:
        docs: A single text string, or an iterable of text strings such as the
            list returned by ``load_json_documents``. Each text is split into
            chunks of at most 1000 characters with 150 characters of overlap
            before embedding.

    Returns:
        A LangChain ``FAISS`` vector store holding the chunks, embedded on CPU
        with ``sentence-transformers/all-MiniLM-l6-v2``.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-l6-v2", model_kwargs={"device": "cpu"})
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    # split_text() accepts a single string only, so split each text in a list individually.
    texts = [docs] if isinstance(docs, str) else list(docs)
    split_docs = [chunk for text in texts for chunk in text_splitter.split_text(text)]
    return FAISS.from_texts(split_docs, embeddings)
def translate(text, target_language, max_tokens=512):
    """Translate *text* into *target_language* with the Mistral instruct model.

    Args:
        text: Text to translate; it is embedded verbatim in a Spanish prompt.
        target_language: Language name inserted after "al" in the prompt
            (for example "inglés").
        max_tokens: Generation budget; output beyond it is cut off. Defaults to 512.

    Returns:
        str: The decoded model output (greedy decoding, temperature 0).
    """
    prompt = (
        f"Por favor, traduzca el siguiente documento al {target_language}:\n{text}\n"
        "Asegúrese de que la traducción sea precisa y conserve el significado original del documento."
    )
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    out_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=0.0,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
def summarize(text, length, max_tokens=512):
    """Summarise *text* with the Mistral instruct model.

    Args:
        text: Text to summarise; it is embedded verbatim in a Spanish prompt.
        length: Spanish phrase inserted after "resumen" in the prompt (main()
            passes e.g. "de aproximadamente 50 palabras").
        max_tokens: Generation budget; output beyond it is cut off. Defaults to
            512. NOTE(review): a ~500-word summary likely needs more than 512 tokens.

    Returns:
        str: The decoded model output (greedy decoding, temperature 0).
    """
    prompt = (
        f"Por favor, haga un resumen {length} del siguiente documento:\n{text}\n"
        "Asegúrese de que el resumen sea conciso y conserve el significado original del documento."
    )
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    out_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=0.0,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
def handle_uploaded_file(uploaded_file):
    """Extract plain text from an uploaded file, chosen by its extension.

    Supports .txt, .pdf, .docx, .csv and .json; extensions match case-insensitively.

    Args:
        uploaded_file: File-like object with a ``name`` attribute and ``read()``,
            such as a Streamlit ``UploadedFile``.

    Returns:
        str: The extracted text. For an unsupported extension, returns the message
        "Tipo de archivo no soportado.". If extraction raises, returns the
        exception message (``str(e)``); errors are returned, not raised.
    """
    try:
        name = uploaded_file.name.lower()
        if name.endswith(".txt"):
            # utf-8-sig also strips a leading byte-order mark, if present.
            return uploaded_file.read().decode("utf-8-sig")
        if name.endswith(".pdf"):
            reader = PdfReader(uploaded_file)
            # extract_text() may yield None or "" for pages without a text layer;
            # pages are newline-separated so words at page breaks do not fuse.
            return "\n".join(page.extract_text() or "" for page in reader.pages)
        if name.endswith(".docx"):
            doc = Document(uploaded_file)
            return "\n".join(para.text for para in doc.paragraphs)
        if name.endswith(".csv"):
            content = uploaded_file.read().decode("utf-8-sig").splitlines()
            # Flatten: cells joined by spaces within a row, rows joined by spaces.
            return " ".join(" ".join(row) for row in csv.reader(content))
        if name.endswith(".json"):
            data = json.load(uploaded_file)
            # Keep accented characters readable instead of \uXXXX escapes.
            return json.dumps(data, indent=4, ensure_ascii=False)
        return "Tipo de archivo no soportado."
    except Exception as e:
        # Best-effort by design: the caller receives the error text as content.
        return str(e)
def _generate_reply(prompt, max_tokens=512):
    """Run one greedy (temperature 0) Mistral chat completion for *prompt* and return the decoded text."""
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    out_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=0.0,
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])


def _answer_about_files(uploaded_files, user_input):
    """Answer *user_input* once per uploaded file, in upload order.

    Each file is converted to text and classified into a category. That
    category's Q&A knowledge base is then searched for passages similar to the
    question, and the passages are given to the model as context.

    Returns:
        list[str]: One reply per uploaded file.
    """
    replies = []
    for uploaded_file in uploaded_files:
        file_content = handle_uploaded_file(uploaded_file)
        classification = classify_text(file_content)
        docs = load_json_documents(classification)
        vector_store = create_vector_store(docs)
        search_docs = vector_store.similarity_search(user_input)
        context = " ".join(doc.page_content for doc in search_docs)
        # NOTE(review): file_content is used only to pick the category; the prompt
        # carries the retrieved knowledge-base passages, not the uploaded text.
        replies.append(_generate_reply(f"Contexto: {context}\n\nPregunta: {user_input}"))
    return replies


# Prompt fragment passed to summarize() for each summary-length option.
_SUMMARY_LENGTHS = {
    "corto": "de aproximadamente 50 palabras",
    "medio": "de aproximadamente 100 palabras",
    "largo": "de aproximadamente 500 palabras",
}


def main():
    """Render the LexAIcon Streamlit page and answer the user's query.

    If files are uploaded, answers the query once per file using retrieved
    knowledge-base context. Otherwise runs the selected operation (summarise,
    translate, or a plain chat reply) on the query text. Each reply is appended
    to ``st.session_state["messages"]`` and written to the page. Nothing is
    generated while the query box is empty.
    """
    st.title("LexAIcon")
    st.write("Puedes conversar con este chatbot basado en Mistral-7B-Instruct y subir archivos para que el chatbot los procese.")
    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    with st.sidebar:
        # Prefilled from the HF_TOKEN environment variable when it is set.
        # NOTE(review): the entered value lands in st.session_state["huggingface_token"],
        # but nothing in this function reads it.
        st.text_input("HuggingFace Token", value=os.environ.get("HF_TOKEN", ""), type="password", key="huggingface_token")
        st.caption("[Consigue un HuggingFace Token](https://huggingface.co/settings/tokens)")

    user_input = st.text_input("Introduce tu consulta:", "")
    operation = st.radio("Selecciona una operación", ["Resumir", "Traducir", "Explicar"])
    target_language = None
    summary_length = None
    if operation == "Traducir":
        target_language = st.selectbox("Selecciona el idioma de traducción", ["español", "inglés", "francés", "alemán"])
    if operation == "Resumir":
        summary_length = st.selectbox("Selecciona la longitud del resumen", ["corto", "medio", "largo"])
    uploaded_files = st.file_uploader("Sube un archivo", type=["txt", "pdf", "docx", "csv", "json"], accept_multiple_files=True)

    # No query yet: do not run the 7B model on an empty prompt.
    if not user_input:
        return

    # NOTE(review): Streamlit reruns the script on every widget interaction while the
    # query text persists, so the same query is re-answered and re-appended on each rerun.
    st.session_state.messages.append({"role": "user", "content": user_input})
    if uploaded_files:
        bot_responses = _answer_about_files(uploaded_files, user_input)
    elif operation == "Resumir":
        bot_responses = [summarize(user_input, _SUMMARY_LENGTHS[summary_length])]
    elif operation == "Traducir":
        bot_responses = [translate(user_input, target_language)]
    else:
        bot_responses = [_generate_reply(user_input)]

    for bot_response in bot_responses:
        st.session_state.messages.append({"role": "assistant", "content": bot_response})
        st.write(f"**Assistant:** {bot_response}")


if __name__ == "__main__":
    main()