chatbot_CDS / model.py
namngo's picture
Update model.py
10cc155 verified
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer,AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from dataset import dataset_a
def load_model(model_name):
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
def generate_answer(model, tokenizer, context, question, max_length=256):
input_text = f'context: {context} question: {question}'
inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
outputs = model.generate(**inputs, max_length=max_length)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
return answer
# if __name__ == "__main__":
# model_name = "D:/Pycharm/Project/Project2/Model/vit5/vit5_base.zip/checkpoint-900" # Thay username và model_name bằng tên mô hình trên Hugging Face
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# print("Mô hình đã load thành công!")
#
# # Nhập dữ liệu để test
# context = """"Công nghệ thông tin đang phát triển mạnh mẽ và trở thành lĩnh vực không thế thiếu trong cuộc sống hiện đại. Thời kỳ hiện nay còn được gọi là thời đại kỹ thuật số, nơi công nghệ luôn thay đổi và phát triển nhanh chóng. Những tiến bộ khoa học công nghệ trong thế kỷ 21 đã tạo ra nhu cầu đào tạo những công dân chúng ta trở thành những công dân số. Công dân số là những người có kỹ năng khai thác, sử dụng Internet và công nghệ một cách an toàn và hiệu quả. Điều này không chỉ đế giải trí mà còn tìm kiếm thông tin, học tập, chia sẻ kiến thức, truyền thông, cũng như tìm hiếu kiến thức và pháp luật.
# Chương 1 cung cấp kiến thức tống quát về thế giới số, công dân số, các yếu tố và kỹ năng cần thiết với công dân số. Những nội dung về chuyến đổi số, số hóa, chữ ký số, chính phủ số, chỉnh phủ điện tử, văn hóa, đạo đức và pháp luật trong thế giới số. Nội dung chính của chương bao gồm:
# - Thế giới số;
# - Công dân số;
# - Chuyển đổi số;
# - Chỉnh phủ điện tử và chính phủ số;
# - Văn hóa, đạo đức và pháp luật trong thế giới số."
# """
# question = "Công nghệ thông tin đang phát triển như thế nào?"
#
# # Sinh câu trả lời
# answer = generate_answer(model, tokenizer, context, question)
# print(f"Answer: {answer}")
# print(model)
def find_context(pos_sentences,question,model,embedings='similarity_embeddings.npz'):
data = np.load(embedings)
pos_embeddings = data["embeddings"]
query_embedding = model.encode(question)
similarities = cosine_similarity([query_embedding], pos_embeddings)
most_similar_idx = np.argmax(similarities)
return pos_sentences[most_similar_idx]