Tri4 committed on
Commit 9d3365a · verified · 1 Parent(s): 6ab5056

Update main.py

Files changed (1)
  1. main.py +72 -4
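
The diff starts at line 12 of main.py, so the import section at the top of the file is not shown. A rough sketch of what those imports presumably look like, inferred only from the names used in the diff below; the exact import block is an assumption, not part of this commit:

```python
# Hypothetical reconstruction of the imports main.py appears to rely on (not shown in this diff).
import os

import torch
from flask import Flask, request, jsonify
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GemmaTokenizerFast,
    GPTQConfig,
    pipeline,
)

app = Flask(__name__)
```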
main.py CHANGED
@@ -12,19 +12,27 @@ app = Flask(__name__)
 
 print("Hello welcome to Sema AI", flush=True)  # Flush to ensure immediate output
 
+@app.route("/")
+def hello():
+    return "hello 🤗, Welcome to Sema AI Chat Service."
+
 # Get Hugging Face credentials from environment variables
 email = os.getenv('HF_EMAIL')
 password = os.getenv('HF_PASS')
 GEMMA_TOKEN = os.getenv("GEMMA_TOKEN")
 #print(f"email is {email} and password is {password}", flush=True)
 
+if not (email and password and GEMMA_TOKEN):
+    print("Missing Hugging Face credentials", flush=True)
+
+"""
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
+model_id = "google/gemma-2-2b-it"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_id = "google/gemma-2-2b-it"
 tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
@@ -33,11 +41,70 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.config.sliding_window = 4096
 model.eval()
+"""
+
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=GEMMA_TOKEN)
+
+quantization_config = GPTQConfig(
+    bits=4,
+    group_size=128,
+    dataset="c4",  # calibration datasets from the GPTQ paper: 'wikitext2', 'c4', 'c4-new', 'ptb', 'ptb-new'
+    desc_act=False,
+    tokenizer=tokenizer,
+    batch_size=1,
+)
+quantized = False
+if quantized:
+    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="google/gemma-2-2b-it",
+                                                 token=GEMMA_TOKEN,
+                                                 quantization_config=quantization_config,
+                                                 device_map=device
+                                                 )
+else:
+    model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="google/gemma-2-2b-it",
+                                                 token=GEMMA_TOKEN,
+                                                 torch_dtype=torch.float16,
+                                                 device_map=device
+                                                 )
+
+
+app_pipeline = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer
+)
+
+@app.route("/generate_text", methods=["POST"])
+def generate_Text():
+    data = request.json
+    prompt = data.get("prompt", "")
+    max_new_tokens = data.get("max_new_tokens", 1000)
+    do_sample = data.get("do_sample", True)
+    temperature = data.get("temperature", 0.1)
+    top_k = data.get("top_k", 50)
+    top_p = data.get("top_p", 0.95)
+
+    # apply_chat_template expects a list of chat messages, not a bare string
+    tokenized_prompt = app_pipeline.tokenizer.apply_chat_template(
+        [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
+    outputs = app_pipeline(
+        tokenized_prompt,
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p
+    )
+
+    return jsonify({"response": outputs[0]["generated_text"][len(tokenized_prompt):]})
+
+
+if __name__ == "__main__":
+    app.run(debug=False, port=8888)
 
-@app.route("/")
-def hello():
-    return "hello 🤗, Welcome to Sema AI Chat Service."
 
+
+"""
 # Flask route to handle incoming chat requests
 @app.route('/chat', methods=['POST'])
 def chat():
@@ -81,3 +148,4 @@ def generate_response(prompt_input, email, passwd):
 
 if __name__ == '__main__':
     app.run(debug=True)
+"""