Commit · 62e4c6c
Parent(s): 8f56420
bugfix

Files changed:
- app.py +43 -40
- requirements.txt +2 -1
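In short: the commit swaps the stubbed-out model loading and the commented-out retrieval code for a working Mistral-7B-Instruct chat pipeline loaded with `device_map="auto"`, adds the missing `gr.File` input to the Gradio interface, and adds `accelerate` to the requirements.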
app.py
CHANGED
@@ -1,68 +1,71 @@
 import os
-
-import spaces
+import torch
 import gradio as gr
 from huggingface_hub import login
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# Log in to the Hugging Face API
 api_token = os.environ.get("HF_API_TOKEN")
 login(api_token)
 
-
-# You can use this section to suppress warnings generated by your code:
-# def warn(*args, **kwargs):
-#     pass
-# import warnings
-# warnings.warn = warn
-# warnings.filterwarnings('ignore')
-
+# Model loading function
 def get_llm(model_id):
-
-    model.
+    # Use `device_map="auto"` to place the model on the available devices automatically
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
     return model
 
-
+# Question-answering logic
 def retriever_qa(file, query):
+    # Load the model and tokenizer
     model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
-    llm = get_llm(model_id)
-    # retriever_obj = retriever(file)
-    # qa = RetrievalQA.from_chain_type(llm=llm,
-    #                                  chain_type="stuff",
-    #                                  retriever=retriever_obj,
-    #                                  return_source_documents=False)
-    # response = qa.invoke(query)
-    with open(file, 'r') as f:
-        first_line = f.readline()
 
-
-
-
-
-
-
-
-
-
-
+    # Make sure CUDA is not initialized in the main thread
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f'Device: {device}')
+
+    # Model loading and inference happen inside a helper function
+    def process_inference(file, query):
+        # Load the model
+        llm = get_llm(model_id)
+
+        # Read the first line of the file
+        with open(file, 'r') as f:
+            first_line = f.readline()
+
+        # Prepare the input
+        messages = [
+            {"role": "user", "content": first_line + query}
+        ]
+        print(messages)
+
+        # Tokenize the input; apply_chat_template returns the input-ids tensor directly
+        model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+        print('Start Inference')
+
+        # Run inference (the ids tensor is passed straight to generate)
+        generated_ids = llm.generate(model_inputs, max_new_tokens=50, do_sample=True)
+
+        # Decode the output
+        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return response
 
-#
-
-# response = response + f". Using device: {device}"
-
+    # Invoke the inference logic
+    response = process_inference(file, query)
     return response
 
+# Gradio interface
 rag_application = gr.Interface(
     fn=retriever_qa,
     allow_flagging="never",
     inputs=[
-
-        gr.
-        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")
+        gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # TXT files only
+        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")  # Query input box
     ],
-    outputs=gr.Textbox(label="Output"),
+    outputs=gr.Textbox(label="Output"),  # Output display box
     title="RAG Chatbot",
     description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document."
 )
 
+# Launch the Gradio app
 rag_application.launch(share=True)
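For context, the new inference path follows the standard chat-template recipe for mistralai/Mistral-7B-Instruct-v0.2 in transformers. A minimal, self-contained sketch of that recipe (the prompt text is a placeholder, not from the commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")  # device_map needs accelerate

device = "cuda" if torch.cuda.is_available() else "cpu"

# Render a single-turn chat with the model's chat template; this returns the input-ids tensor.
messages = [{"role": "user", "content": "What does this document say? Hello world."}]  # placeholder prompt
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)

# Sample up to 50 new tokens, matching the generation settings used in the commit.
generated_ids = model.generate(input_ids, max_new_tokens=50, do_sample=True)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])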
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
 transformers==4.36.0
-sentencepiece
+sentencepiece
+accelerate