Tri4 committed
Commit d8c74e9 · verified · 1 Parent(s): 7ae3c98

Create main.py

Files changed (1):
  1. main.py +86 -0
main.py ADDED
@@ -0,0 +1,86 @@
+ from flask import Flask, request, jsonify
+ from huggingface_hub import InferenceClient
+
+ # Initialize Flask app
+ app = Flask(__name__)
+
+ print("\nHello welcome to Sema AI\n", flush=True)  # Flush to ensure immediate output
+
+ # Initialize InferenceClient
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
+
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+     # Print user prompt
+     print(f"\nUser: {prompt}\n", flush=True)
+
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(prompt, history)
+
+     # Get response from Mistral model
+     response = client.text_generation(
+         formatted_prompt,
+         **generate_kwargs,
+         stream=True,
+         details=True,
+         return_full_text=False
+     )
+
+     output = ""
+     for token in response:
+         if hasattr(token, 'token') and hasattr(token.token, 'text'):
+             output += token.token.text
+         else:
+             print(f"Unexpected token structure: {token}", flush=True)
+
+     # Print AI response
+     print(f"\nSema AI: {output}\n", flush=True)
+     return output
+
+ @app.route("/generate", methods=["POST"])
+ def generate_text():
+     data = request.json
+     prompt = data.get("prompt", "")
+     history = data.get("history", [])
+     temperature = data.get("temperature", 0.9)
+     max_new_tokens = data.get("max_new_tokens", 256)
+     top_p = data.get("top_p", 0.95)
+     repetition_penalty = data.get("repetition_penalty", 1.0)
+
+     try:
+         response_text = generate(
+             prompt,
+             history,
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty
+         )
+         return jsonify({"response": response_text})
+     except Exception as e:
+         # Print error
+         print(f"Error: {str(e)}", flush=True)
+         return jsonify({"error": str(e)}), 500
+
+ if __name__ == "__main__":
+     app.run(debug=True, port=5000)
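
For context, format_prompt serializes the chat history into the [INST] ... [/INST] markup that Mistral-7B-Instruct expects, where each turn is a [user_message, bot_response] pair. A quick illustration of its output (the messages here are made up):

    # One prior turn, given as a [user_message, bot_response] pair:
    history = [["Hello", "Hi! How can I help?"]]
    format_prompt("What is Flask?", history)
    # -> '<s>[INST] Hello [/INST] Hi! How can I help?</s> [INST] What is Flask? [/INST]'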
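
Once the server is running, the /generate endpoint accepts a JSON body whose fields mirror the parameters of generate. A minimal client sketch, assuming the app is served locally on port 5000 (as in app.run above) and the third-party requests package is installed; the prompt text and parameter values are only examples:

    import requests

    payload = {
        "prompt": "What is the capital of Kenya?",
        "history": [],  # optional: list of [user_message, bot_response] pairs
        "temperature": 0.9,
        "max_new_tokens": 256,
    }

    # POST to the Flask endpoint and print the generated text
    resp = requests.post("http://127.0.0.1:5000/generate", json=payload, timeout=120)
    resp.raise_for_status()
    print(resp.json()["response"])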