Seunggg committed on
Commit 75ecc06 · verified · 1 Parent(s): e45ed98

Update app.py

Files changed (1)
  1. app.py +54 -22
app.py CHANGED
@@ -1,25 +1,57 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
  import torch
 
- # Your own model repo
- model_id = "Seunggg/lora-plant"
-
- # Load the model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto")
-
- # Define the interface function
- def plant_chat(user_input):
-     prompt = f"用户提问:{user_input}\n请用人性化语言回答,并推荐相关的植物资料或文献:\n回答:"
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_new_tokens=256)
-     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return answer
-
- # Launch the Gradio interface
- gr.Interface(fn=plant_chat,
-              inputs="text",
-              outputs="text",
-              title="🌿 植物问答助手",
-              description="根据你的问题,提供植物养护建议和文献线索。").launch()
+ model_id = "deepseek-ai/deepseek-coder-1.3b-base"
+ lora_id = "Seunggg/lora-plant"
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+
+ # Load the base model with automatic device placement and CPU/disk offloading
+ base = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     offload_folder="offload/",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     trust_remote_code=True
+ )
+
+ # Load the LoRA adapter, with offloading enabled as well
+ model = PeftModel.from_pretrained(
+     base,
+     lora_id,
+     offload_folder="offload/",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+ )
+
+ model.eval()
+
+ # Text-generation pipeline
+ from transformers import pipeline
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     device_map="auto",
+     max_new_tokens=256
+ )
+
+ def respond(user_input):
+     if not user_input.strip():
+         return "请输入植物相关的问题 :)"
+     prompt = f"用户提问:{user_input}\n请用更人性化的语言生成建议,并推荐相关植物文献或资料。\n回答:"
+     result = pipe(prompt, return_full_text=False)  # return only the completion, not the echoed prompt
+     return result[0]["generated_text"]
+
+ # Gradio interface
+ gr.Interface(
+     fn=respond,
+     inputs=gr.Textbox(lines=4, placeholder="在这里输入你的植物问题..."),
+     outputs="text",
+     title="🌱 植物助手 LoRA 版",
+     description="基于 DeepSeek 微调模型,提供植物养护建议和文献推荐。",
+     allow_flagging="never"
+ ).launch()
+
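
For a quick smoke test of the deployed app, the Interface can also be queried programmatically with gradio_client. A minimal sketch, assuming the app runs in a Space named Seunggg/lora-plant; the actual Space id is not shown in this commit, so treat the name as a placeholder:

    from gradio_client import Client

    # Placeholder Space id, replace with the Space actually hosting this app.py
    client = Client("Seunggg/lora-plant")

    # gr.Interface exposes the wrapped function under the default /predict endpoint
    answer = client.predict("多肉植物叶子发软怎么办?", api_name="/predict")
    print(answer)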
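
Because the app reloads the base model and re-applies the adapter on every startup, a possible follow-up (not part of this commit) is to merge the LoRA weights into the base model once with PEFT's merge_and_unload() and serve the merged checkpoint instead. A sketch under that assumption; the output directory name is hypothetical:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    import torch

    base = AutoModelForCausalLM.from_pretrained(
        "deepseek-ai/deepseek-coder-1.3b-base",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
    )
    # Fold the adapter weights into the base model and drop the PEFT wrappers
    merged = PeftModel.from_pretrained(base, "Seunggg/lora-plant").merge_and_unload()

    # "merged-plant-model" is a placeholder output directory
    merged.save_pretrained("merged-plant-model")
    AutoTokenizer.from_pretrained(
        "deepseek-ai/deepseek-coder-1.3b-base", trust_remote_code=True
    ).save_pretrained("merged-plant-model")

The merged checkpoint then loads with a plain AutoModelForCausalLM.from_pretrained call, removing the peft dependency and the adapter-application step at serving time.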