Christoph Holthaus committed
Commit 992ad15 · Parent: 81c62e1

DEV: Use mistral template
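
The commit title refers to Mistral's instruct prompt format, which the new app.py obtains through the tokenizer's built-in chat template (the apply_chat_template call in the diff below). As a minimal sketch, not part of the commit, this is roughly what that call produces when asked for the formatted string instead of token ids; it assumes the stock mistralai/Mistral-7B-Instruct-v0.1 tokenizer, and the example messages are borrowed from the examples list in the diff:

    # Illustrative sketch: render a conversation with the Mistral instruct template.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
    conversation = [
        {"role": "user", "content": "Hello there! How are you doing?"},
        {"role": "assistant", "content": "Doing well, thanks!"},
        {"role": "user", "content": "Explain the plot of Cinderella in a sentence."},
    ]
    # tokenize=False returns the formatted prompt string; the app itself passes
    # return_tensors="pt" to get input ids directly.
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
    print(prompt)
    # Expected shape (exact whitespace may differ by tokenizer version):
    # <s>[INST] Hello there! How are you doing? [/INST]Doing well, thanks!</s>[INST] Explain the plot of Cinderella in a sentence. [/INST]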

Files changed (1)
  1. app.py +130 -58
app.py CHANGED
@@ -1,4 +1,5 @@
-# Importing libraries
+#!/usr/bin/env python#
+
 from llama_cpp import Llama
 from time import time
 import gradio as gr
@@ -20,63 +21,134 @@ It uses quantized models in gguf-Format and llama.cpp to run them.
 <br>''' + f"DEBUG: Memory used: {psutil.virtual_memory()[2]}<br>" + '''
 Powered by ...'''
 
-# Loading prompt
-prompt = ""
-system_message = ""
-
-def generate_answer(request: str, max_tokens: int = 256, custom_prompt: str = None):
-    t0 = time()
-    logs = f"Request: {request}\nMax tokens: {max_tokens}\nCustom prompt: {custom_prompt}\n"
-    try:
-        maxTokens = max_tokens if 16 <= max_tokens <= 256 else 64
-        userPrompt = prompt.replace("{prompt}", request)
-        userPrompt = userPrompt.replace(
-            "{system_message}",
-            custom_prompt if isinstance(custom_prompt, str) and len(custom_prompt.strip()) > 1 and custom_prompt.strip() not in ['', None, ' '] else system_message
-        )
-        logs += f"\nFinal prompt: {userPrompt}\n"
-    except:
-        return "Not enough data! Check that you passed all needed data.", logs
-
-    try:
-        # this shitty fix will be until i willnt figure out why sometimes there is empty output
-        counter = 1
-        while counter <= 3:
-            logs += f"Attempt {counter} to generate answer...\n"
-            output = llm(userPrompt, max_tokens=maxTokens, stop=["<|im_end|>"], echo=False)
-            text = output["choices"][0]["text"]
-            if len(text.strip()) > 1 and text.strip() not in ['', None, ' ']:
-                break
-            counter += 1
-        logs += f"Final attempt: {counter}\n"
-        if len(text.strip()) <= 1 or text.strip() in ['', None, ' ']:
-            logs += f"Generated and aborted: {text}"
-            text = "Sorry, but something went wrong while generating answer. Try again or fix code. If you are maintainer of this space, look into logs."
-
-        logs += f"\nFinal: '''{text}'''"
-        logs += f"\n\nTime spent: {time()-t0}"
-        return text, logs
-    except Exception as e:
-        logs += str(e)
-        logs += f"\n\nTime spent: {time()-t0}"
-        return "Oops! Internal server error. Check the logs of space/instance.", logs
-
-print("! LOAD GRADIO INTERFACE !")
-demo = gr.Interface(
-    fn=generate_answer,
-    inputs=[
-        gr.components.Textbox(label="Input"),
-        gr.components.Number(value=256),
-        gr.components.Textbox(label="Custom system prompt"),
+
+
+import os
+from threading import Thread
+from typing import Iterator
+
+import gradio as gr
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+DESCRIPTION = "# Mistral-7B"
+
+if torch.cuda.is_available():
+    DESCRIPTION += "\n<p>This space is optimized for CPU only. Use another one if you want to go fast and use GPU. </p>"
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+
+#download model here
+# check localstorage, if no there, load, else use existing.
+
+if torch.cuda.is_available():
+    model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
     ],
-    outputs=[
-        gr.components.Textbox(label="Output"),
-        gr.components.Textbox(label="Logs")
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
-    title=title,
-    description=desc,
-    allow_flagging='never'
 )
-demo.queue()
-print("! LAUNCHING GRADIO !")
-demo.launch(server_name="0.0.0.0")
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
+    chat_interface.render()
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
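
The generate() function added above is a plain Python generator that yields the accumulated text after each streamed chunk, so it can also be exercised outside the Gradio UI. A minimal sketch, not part of the commit, assuming the model and tokenizer defined in the diff have already been loaded; the prompt is one of the examples from the diff:

    # Illustrative sketch: consume the streaming generator directly.
    # Each yielded value is the full text generated so far, so print only the new suffix.
    history: list[tuple[str, str]] = []
    printed = ""
    for partial in generate("Explain the plot of Cinderella in a sentence.", history, max_new_tokens=128):
        print(partial[len(printed):], end="", flush=True)
        printed = partial
    print()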