InspirationYF committed
Commit 62e4c6c · 1 Parent(s): 8f56420
Files changed (2):
  1. app.py +43 -40
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,68 +1,71 @@
  import os
- # import torch
- import spaces
  import gradio as gr
  from huggingface_hub import login
  from transformers import AutoModelForCausalLM, AutoTokenizer

  api_token = os.environ.get("HF_API_TOKEN")
  login(api_token)

-
- # You can use this section to suppress warnings generated by your code:
- # def warn(*args, **kwargs):
- #     pass
- # import warnings
- # warnings.warn = warn
- # warnings.filterwarnings('ignore')
-
  def get_llm(model_id):
-     model = AutoModelForCausalLM.from_pretrained(model_id)
-     model.to('cuda')
      return model

- @spaces.GPU
  def retriever_qa(file, query):
      model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
      tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
-     llm = get_llm(model_id)
-     # retriever_obj = retriever(file)
-     # qa = RetrievalQA.from_chain_type(llm=llm,
-     #                                  chain_type="stuff",
-     #                                  retriever=retriever_obj,
-     #                                  return_source_documents=False)
-     # response = qa.invoke(query)
-     with open(file, 'r') as f:
-         first_line = f.readline()

-     messages = [
-         {"role": "user", "content": first_line + query}
-     ]
-     print(messages)
-     model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
-     print('Start Inference')
-     generated_ids = llm.generate(model_inputs, max_new_tokens=50, do_sample=True)
-     response = generated_ids
-     # print('Start detokenize')
-     # response = tokenizer.batch_decode(generated_ids)[0]

-     # # Check if a GPU is available
-     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     # response = response + f". Using device: {device}"
-
      return response

  rag_application = gr.Interface(
      fn=retriever_qa,
      allow_flagging="never",
      inputs=[
-         # gr.File(label="Upload PDF File", file_count="single", file_types=['.pdf'], type="filepath"),  # Drag and drop file upload
-         gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # Drag and drop file upload
-         gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")
      ],
-     outputs=gr.Textbox(label="Output"),
      title="RAG Chatbot",
      description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document."
  )

  rag_application.launch(share=True)
 
  import os
+ import torch
  import gradio as gr
  from huggingface_hub import login
  from transformers import AutoModelForCausalLM, AutoTokenizer

+ # Log in to the Hugging Face API
  api_token = os.environ.get("HF_API_TOKEN")
  login(api_token)

+ # Model loading function
  def get_llm(model_id):
+     # Use `device_map="auto"` to place the model on the available device(s) automatically
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
      return model

+ # Question-answering logic
  def retriever_qa(file, query):
+     # Load the model and tokenizer
      model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
      tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)

+     # Make sure CUDA initialization does not happen in the main thread
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f'Device: {device}')
+
+     # Do the model loading and inference inside a nested function
+     def process_inference(file, query):
+         # Load the model
+         llm = get_llm(model_id)
+
+         # Read the first line of the uploaded file
+         with open(file, 'r') as f:
+             first_line = f.readline()
+
+         # Prepare the input
+         messages = [
+             {"role": "user", "content": first_line + query}
+         ]
+         print(messages)
+
+         # Tokenize the input
+         model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+         print('Start Inference')
+
+         # Run inference
+         generated_ids = llm.generate(model_inputs['input_ids'], max_new_tokens=50, do_sample=True)
+
+         # Decode the output
+         response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+         return response

+     # Call the inference logic
+     response = process_inference(file, query)
      return response

+ # Gradio interface
  rag_application = gr.Interface(
      fn=retriever_qa,
      allow_flagging="never",
      inputs=[
+         gr.File(label="Upload txt File", file_count="single", file_types=['.txt'], type="filepath"),  # TXT files only
+         gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")  # Query input box
      ],
+     outputs=gr.Textbox(label="Output"),  # Output display box
      title="RAG Chatbot",
      description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document."
  )

+ # Launch the Gradio app
  rag_application.launch(share=True)
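
One caveat on the new inference path: with the pinned transformers==4.36.0, tokenizer.apply_chat_template(..., return_tensors="pt") normally returns a tensor of token ids rather than a dict, so indexing the result with ['input_ids'] may fail at runtime. A minimal, self-contained sketch of the same call pattern under that assumption (illustrative only, not part of this commit; the example prompt is made up):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    llm = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    # Hypothetical example input; in the app this would be the file's first line plus the query.
    messages = [{"role": "user", "content": "What is retrieval-augmented generation?"}]

    # apply_chat_template with return_tensors="pt" yields a token-id tensor here,
    # so it is passed to generate() directly instead of being indexed like a dict.
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(llm.device)
    generated_ids = llm.generate(input_ids, max_new_tokens=50, do_sample=True)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])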
requirements.txt CHANGED
@@ -1,2 +1,3 @@
  transformers==4.36.0
- sentencepiece

  transformers==4.36.0
+ sentencepiece
+ accelerate
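
The added accelerate dependency pairs with the switch to device_map="auto" in get_llm: transformers hands weight placement to accelerate when a device map is requested, and the load fails if the package is missing. A small sketch (assuming a single-GPU or CPU Space) of how the resulting placement can be inspected:

    from transformers import AutoModelForCausalLM

    # device_map="auto" requires the accelerate package to be installed.
    model = AutoModelForCausalLM.from_pretrained(
        'mistralai/Mistral-7B-Instruct-v0.2',
        device_map="auto",
    )
    print(model.hf_device_map)  # e.g. {'': 0} when the whole model fits on one GPU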