Tri4 committed
Commit d36e009 · verified · 1 parent: 331935a

Create main.py

Files changed (1): main.py (+80 −0)
main.py ADDED
@@ -0,0 +1,80 @@
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient

# Initialize Flask app
app = Flask(__name__)

print("\nHello, welcome to Sema AI\n", flush=True)  # flush to ensure immediate output

@app.route("/")
def hello():
    return "hello 🤗, Welcome to Sema AI Chat Service."

# Initialize the client for the hosted Mistral-7B-Instruct model
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def format_prompt(message, history):
    # Wrap each past (user, bot) turn in Mistral's [INST] ... [/INST] format,
    # then append the new user message as the final instruction.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

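# For example, with history [("Hello", "Hi, I'm Sema AI.")] and the new message
# "What can you do?", format_prompt returns (illustrative values):
#   <s>[INST] Hello [/INST] Hi, I'm Sema AI.</s> [INST] What can you do? [/INST]
# which mirrors the [INST] ... [/INST] chat template Mistral-7B-Instruct expects.
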
def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # clamp very small values to keep sampling stable
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Stream the response from the Mistral model and accumulate the generated tokens
    response = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for token in response:
        output += token.token.text
    return output

@app.route("/generate", methods=["POST"])
def generate_text():
    data = request.get_json(silent=True) or {}  # tolerate missing/invalid JSON bodies
    prompt = data.get("prompt", "")
    history = data.get("history", [])
    temperature = data.get("temperature", 0.9)
    max_new_tokens = data.get("max_new_tokens", 256)
    top_p = data.get("top_p", 0.95)
    repetition_penalty = data.get("repetition_penalty", 1.0)
    print(f"{prompt}: \n")  # log the incoming prompt

    try:
        response_text = generate(
            prompt,
            history,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        return jsonify({"response": response_text})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    app.run(debug=True, port=5000)
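
For reference, a minimal client-side sketch of calling the /generate route above, assuming the server is running locally with the default app.run settings; the prompt, history turns, sampling values, and the requests dependency are illustrative and not part of this commit:

import requests

# Assumes the Flask app above is running on localhost:5000.
payload = {
    "prompt": "What can you do?",                # new user message
    "history": [["Hello", "Hi, I'm Sema AI."]],  # prior (user, bot) turns
    "temperature": 0.7,
    "max_new_tokens": 128,
}
resp = requests.post("http://127.0.0.1:5000/generate", json=payload)
print(resp.json())  # {"response": "..."} on success, {"error": "..."} with HTTP 500 on failure

Note that requests against the hosted Inference API may require authentication; InferenceClient also accepts a token argument, which this commit does not set.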