import os
import torch
import gradio as gr
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
# Log in to the Hugging Face Hub
api_token = os.environ.get("HF_API_TOKEN")
login(api_token)
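# Note (assumption, not in the original): HF_API_TOKEN is expected to be provided as an
# environment variable (e.g. a Space secret); login() is expected to fail if it is missing or invalid.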
# Model loading function
def get_llm(model_id):
    # Use `device_map="auto"` to place the model on the available device(s) automatically
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    return model
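# Note (assumption, not in the original): device_map="auto" relies on the `accelerate`
# package being installed; without it, from_pretrained will raise an error.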
# Question-answering logic
def retriever_qa(file, query):
    # Load the tokenizer
    model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

    # Make sure CUDA initialization does not happen on the main thread;
    # pick CUDA if available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Device: {device}')
    # Model loading and inference are done inside this nested helper
    def process_inference(file, query):
        # Load the model
        llm = get_llm(model_id)

        # Read the first line of the uploaded file
        with open(file, 'r') as f:
            first_line = f.readline()

        # Build the chat input: the file's first line is prepended to the query
        messages = [
            {"role": "user", "content": first_line + query}
        ]
        print(messages)

        # Tokenize the input (return_dict=True so input_ids/attention_mask can be passed to generate)
        model_inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", return_dict=True
        ).to(device)

        print('Start Inference')
        # Run generation
        generated_ids = llm.generate(**model_inputs, max_new_tokens=50, do_sample=True)

        # Decode the output
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
    # Run the inference logic
    response = process_inference(file, query)
    return response
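# Hypothetical local check (not part of the original app): calling retriever_qa directly,
# assuming a plain-text file named sample.txt exists next to this script:
#   print(retriever_qa("sample.txt", "What is this document about?"))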
# Gradio interface
rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # TXT files only
        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")  # query input box
    ],
    outputs=gr.Textbox(label="Output"),  # output display box
    title="RAG Chatbot",
    description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document."
)
# Launch the Gradio app
rag_application.launch(share=True)
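# Running locally (a sketch, not part of the original Space configuration):
#   HF_API_TOKEN=<your token> python app.py
# On Hugging Face Spaces the app is launched automatically, and share=True typically has no effect there.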