InspirationYF committed
Commit de90557 · 1 Parent(s): 4e96b59

feat: add mistral

Files changed (1)
  1. app.py +31 -6
app.py CHANGED
@@ -1,8 +1,12 @@
-import torch
+import spaces
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# Check if a GPU is available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
+model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+
+# # Check if a GPU is available
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# print(f"Using device: {device}")
 
 import gradio as gr
 
@@ -13,8 +17,15 @@ import warnings
 warnings.warn = warn
 warnings.filterwarnings('ignore')
 
+def get_llm():
+    model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
+    model.to('cuda')
+    return model
+
+@spaces.GPU
 def retriever_qa(file, query):
-    # llm = get_llm()
+    llm = get_llm()
     # retriever_obj = retriever(file)
     # qa = RetrievalQA.from_chain_type(llm=llm,
     #                                  chain_type="stuff",
@@ -23,8 +34,22 @@ def retriever_qa(file, query):
     # response = qa.invoke(query)
     with open(file, 'r') as f:
         first_line = f.readline()
+
+    messages = [
+        {"role": "user", "content": first_line}
+    ]
+
+    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
+
+    generated_ids = llm.generate(model_inputs, max_new_tokens=100, do_sample=True)
+    # tokenizer.batch_decode(generated_ids)[0]
 
-    response = first_line + query
+    response = tokenizer.batch_decode(generated_ids)[0]
+
+    # # Check if a GPU is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # print(f"Using device: {device}")
+    response = response + f". Using device: {device}"
 
     return response
 
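
Review note: the commit removes `import torch` from the top of app.py, but the new file's line 50 (`device = torch.device(...)` inside `retriever_qa`) is left uncommented, so the handler will raise `NameError: name 'torch' is not defined` the first time a query runs unless torch is imported elsewhere in the file. Below is a minimal follow-up sketch (not part of this commit); `device_report` is a hypothetical helper name, and it assumes nothing else in app.py imports torch.

# Sketch of a follow-up fix, assuming app.py no longer imports torch anywhere:
# restore the import so the uncommented torch.device(...) call still resolves.
import torch

def device_report() -> str:
    # Same check the committed code appends to the model response.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return f". Using device: {device}"

print(device_report())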
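Review note: `get_llm()` calls `AutoModelForCausalLM.from_pretrained(...)` inside the `@spaces.GPU`-decorated `retriever_qa`, so the 7B checkpoint is reloaded on every request even though a module-level `model` and `tokenizer` are already created at import time. A hedged sketch of the alternative is below; it assumes the `spaces` package is available (as on Hugging Face ZeroGPU Spaces) and simply reuses the module-level objects this commit already defines. It is a sketch of one possible layout, not the committed implementation.

# Sketch only: reuse the module-level model/tokenizer inside the GPU-decorated
# handler instead of calling from_pretrained on every request.
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

@spaces.GPU
def retriever_qa(file, query):
    # Read the first line of the uploaded file (query is unused, matching the commit).
    with open(file, "r") as f:
        first_line = f.readline()

    # Wrap it in Mistral's chat template, generate, and decode.
    messages = [{"role": "user", "content": first_line}]
    model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
    generated_ids = model.generate(model_inputs, max_new_tokens=100, do_sample=True)
    return tokenizer.batch_decode(generated_ids)[0]

Loading the weights once at import time avoids repeating a multi-gigabyte `from_pretrained` call per Gradio request, which is what the committed `get_llm()` path does each time `retriever_qa` runs.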