shanaka95 committed (verified)
Commit f3660aa · Parent: 81690c6

Update README.md

Gradio code updated: consolidated the transformers imports, simplified device selection, moved the model-preparation block next to model loading, migrated components from `gr.components.*`/`gr.inputs.*` to top-level `gr.Textbox`/`gr.Slider`, and fixed the broken dataset link in the interface description.

Files changed (1)

  1. README.md  +25 −45
README.md CHANGED
@@ -99,16 +99,14 @@ Output: Bullish:38751.32,38818.6,38818.6,38695.03
 import sys
 import torch
 from peft import PeftModel
-import transformers
+from transformers import (
+    LlamaTokenizer,
+    LlamaForCausalLM,
+    GenerationConfig
+)
 import gradio as gr
 
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
-
-
-SHARE_GRADIO=True
+SHARE_GRADIO = True
 LOAD_8BIT = False
 
 BASE_MODEL = "mrzlab630/weights_Llama_7b"
@@ -116,16 +114,9 @@ LORA_WEIGHTS = "mrzlab630/lora-alpaca-trading-candles"
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-
-try:
-    if torch.backends.mps.is_available():
-        device = "mps"
-except:
-    pass
+device = "cuda" if torch.cuda.is_available() else "cpu"
+if torch.backends.mps.is_available():
+    device = "mps"
 
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
@@ -161,6 +152,12 @@ else:
         device_map={"": device},
     )
 
+if not LOAD_8BIT:
+    model.half()
+
+model.eval()
+if torch.__version__ >= "2" and sys.platform != "win32":
+    model = torch.compile(model)
 
 def generate_prompt(instruction, input=None):
     if input:
@@ -181,14 +178,6 @@ def generate_prompt(instruction, input=None):
 
 ### Response:"""
 
-if not LOAD_8BIT:
-    model.half()  # seems to fix bugs for some users.
-
-model.eval()
-if torch.__version__ >= "2" and sys.platform != "win32":
-    model = torch.compile(model)
-
-
 def evaluate(
     instruction,
     input=None,
@@ -221,31 +210,22 @@ def evaluate(
     output = tokenizer.decode(s)
     return output.split("### Response:")[1].strip()
 
-
 gr.Interface(
     fn=evaluate,
     inputs=[
-        gr.components.Textbox(
-            lines=2, label="Instruction", placeholder="Tell me about alpacas."
-        ),
-        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
-        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
-        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
-        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
-        gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
-        gr.components.Slider(
-            minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
-        ),
-    ],
-    outputs=[
-        gr.inputs.Textbox(
-            lines=5,
-            label="Output",
-        )
+        gr.Textbox(lines=2, label="Instruction", placeholder="Tell me about alpacas."),
+        gr.Textbox(lines=2, label="Input", placeholder="none"),
+        gr.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
+        gr.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
+        gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
+        gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
+        gr.Slider(minimum=1, maximum=2000, step=1, value=128, label="Max tokens"),
     ],
+    outputs=gr.Textbox(lines=5, label="Output"),
     title="💹 🕯 Alpaca-LoRA-Trading-Candles",
-    description="Alpaca-LoRA-Trading-Candles is a 7B-parameter LLaMA model tuned to execute instructions. It is trained on the [trading candles] dataset(https://huggingface.co/datasets/mrzlab630/trading-candles) and uses the Huggingface LLaMA implementation. For more information, visit [project website](https://huggingface.co/mrzlab630/lora-alpaca-trading-candles).\nPrompts:\nInstruction: identify candle, Input: open:241.5,close:232.9, high:241.7, low:230.8\nInstruction: find candle, Input: 38811.24,38838.41,38846.71,38736.24,234.00,45275276.00,59816.00,441285.00,645.00,84176.00,1694619.00,15732335.00\nInstruction: find candle: Bullish, Input: 38751.32,38818.6,38818.6,38695.03,62759348.00,2605789.00,71030.00,820738.00,59659.00,724738.00,7368363.00,50654.00",
+    description="""Alpaca-LoRA-Trading-Candles is a 7B-parameter LLaMA model tuned to execute instructions. It is trained on the [trading candles](https://huggingface.co/datasets/mrzlab630/trading-candles) dataset and uses the Huggingface LLaMA implementation. For more information, visit [project website](https://huggingface.co/mrzlab630/lora-alpaca-trading-candles).""",
 ).launch(server_name="0.0.0.0", share=SHARE_GRADIO)
 
 
+
 ```
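
Note on the device-selection hunk: the rewrite drops the old try/except but calls `torch.backends.mps.is_available()` unguarded, and `torch.backends.mps` only exists on PyTorch 1.12 and later. A minimal sketch of a variant that keeps the commit's precedence (an available MPS backend overrides CUDA) while staying safe on older builds; the `pick_device` helper name is illustrative, not part of the commit:

```python
import torch

def pick_device() -> str:
    # Illustrative helper, not from the commit: same precedence as the new
    # README code (MPS overrides CUDA), but the MPS probe is guarded so it
    # also runs on PyTorch < 1.12, where torch.backends.mps does not exist.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        device = "mps"
    return device

device = pick_device()
```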
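Relocating the post-load block (`model.half()`, `model.eval()`, optional `torch.compile`) from after `generate_prompt` to right after model loading changes only the file layout; both versions run it once at import time, before `evaluate` is ever called. A compact sketch of that step as a function, under the same guards; `prepare_model` is a hypothetical name, not in the README:

```python
import sys
import torch

def prepare_model(model, load_8bit: bool = False):
    # Hypothetical wrapper around the README's post-load step: cast to fp16
    # only when the weights are not already 8-bit quantized, switch to
    # inference mode, and compile where supported (the commit guards on
    # PyTorch >= 2 and a non-Windows platform).
    if not load_8bit:
        model.half()
    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    return model
```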
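The component migration (`gr.components.Textbox`/`gr.inputs.Textbox` to top-level `gr.Textbox`, and a bare component instead of a one-element `outputs` list) tracks the Gradio 3.x API, where the `gr.inputs`/`gr.outputs` modules were deprecated. A runnable sketch of the same interface shape with the model call stubbed out, so the UI can be smoke-tested without downloading the 7B weights; the `echo` function is a placeholder, not the repository's `evaluate`:

```python
import gradio as gr

def echo(instruction, input, temperature):
    # Placeholder for evaluate(): just echoes the arguments back.
    return f"instruction={instruction!r} input={input!r} temperature={temperature}"

demo = gr.Interface(
    fn=echo,
    inputs=[
        gr.Textbox(lines=2, label="Instruction", placeholder="Tell me about alpacas."),
        gr.Textbox(lines=2, label="Input", placeholder="none"),
        gr.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(lines=5, label="Output"),  # bare component, no list needed
    title="💹 🕯 Alpaca-LoRA-Trading-Candles (stub)",
)

if __name__ == "__main__":
    demo.launch()
```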