p1atdev commited on
Commit
442b8aa
·
verified ·
1 Parent(s): 008babf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -25,21 +25,17 @@ from transformers import (
25
  from threading import Thread
26
 
27
  import gradio as gr
28
- from dotenv import load_dotenv
29
 
30
  import spaces
31
 
32
 
33
- load_dotenv()
34
-
35
- HF_API_KEY = os.getenv("HF_API_KEY")
36
  MODEL_NAME_MAP = {
37
  "150m-instruct3": "llm-jp/llm-jp-3-150m-instruct3",
38
  "440m-instruct3": "llm-jp/llm-jp-3-440m-instruct3",
39
  "980m-instruct3": "llm-jp/llm-jp-3-980m-instruct3",
40
- # "1.8b-instruct3": "llm-jp/llm-jp-3-1.8b-instruct3",
41
- # "3.7b-instruct3": "llm-jp/llm-jp-3-3.7b-instruct3",
42
- # "13b-instruct3": "llm-jp/llm-jp-3-13b-instruct3",
43
  }
44
 
45
  quantization_config = BitsAndBytesConfig(
@@ -50,12 +46,14 @@ quantization_config = BitsAndBytesConfig(
50
  )
51
  MODELS = {
52
  key: AutoModelForCausalLM.from_pretrained(
53
- repo_id, quantization_config=quantization_config, device_map="auto"
 
 
 
54
  ) for key, repo_id in MODEL_NAME_MAP.items()
55
  }
56
  TOKENIZERS = {
57
  key: AutoTokenizer.from_pretrained(repo_id) for key, repo_id in MODEL_NAME_MAP.items()
58
-
59
  }
60
 
61
  print("Compiling model...")
 
25
  from threading import Thread
26
 
27
  import gradio as gr
 
28
 
29
  import spaces
30
 
31
 
 
 
 
32
  MODEL_NAME_MAP = {
33
  "150m-instruct3": "llm-jp/llm-jp-3-150m-instruct3",
34
  "440m-instruct3": "llm-jp/llm-jp-3-440m-instruct3",
35
  "980m-instruct3": "llm-jp/llm-jp-3-980m-instruct3",
36
+ "1.8b-instruct3": "llm-jp/llm-jp-3-1.8b-instruct3",
37
+ "3.7b-instruct3": "llm-jp/llm-jp-3-3.7b-instruct3",
38
+ "13b-instruct3": "llm-jp/llm-jp-3-13b-instruct3",
39
  }
40
 
41
  quantization_config = BitsAndBytesConfig(
 
46
  )
47
  MODELS = {
48
  key: AutoModelForCausalLM.from_pretrained(
49
+ repo_id,
50
+ quantization_config=quantization_config,
51
+ device_map="auto",
52
+ attn_implementation="flash_attention_2",
53
  ) for key, repo_id in MODEL_NAME_MAP.items()
54
  }
55
  TOKENIZERS = {
56
  key: AutoTokenizer.from_pretrained(repo_id) for key, repo_id in MODEL_NAME_MAP.items()
 
57
  }
58
 
59
  print("Compiling model...")