import spaces
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import json

title = "# 🙋🏻‍♂️Welcome to 🌟Tonic's 🌕💉👨🏻‍🔬Moonshot Math"

description = """
**🌕💉👨🏻‍🔬AI-MO/Kimina-Prover-Distill-8B** is a theorem-proving model developed by the Project Numina and Kimi teams, focusing on competition-style problem solving in Lean 4. It is a distillation of AI-MO/Kimina-Prover-72B, a model trained via large-scale reinforcement learning. It achieves 77.86% accuracy with Pass@32 on MiniF2F-test.
- [Kimina-Prover-Preview GitHub](https://github.com/MoonshotAI/Kimina-Prover-Preview)
- [Hugging Face: AI-MO/Kimina-Prover-72B](https://huggingface.co/AI-MO/Kimina-Prover-72B)
- [Kimina Prover blog](https://huggingface.co/blog/AI-MO/kimina-prover)
- [unimath dataset](https://huggingface.co/datasets/introspector/unimath)
"""

citation = """> **Citation:**
> ```
> @article{kimina_prover_2025,
>   title = {Kimina-Prover Preview: Towards Large Formal Reasoning Models with Reinforcement Learning},
>   author = {Wang, Haiming and Unsal, Mert and ...},
>   year = {2025},
>   url = {http://arxiv.org/abs/2504.11354},
> }
> ```
"""


joinus = """
## Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builders' 🛠️ community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP). On 🤗Hugging Face: [MultiTransformer](https://huggingface.co/MultiTransformer). On 🌐GitHub: [Tonic-AI](https://github.com/tonic-ai), and contribute to 🌟[MultiTonic](https://github.com/MultiTonic). 🤗 Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant!
"""

SYSTEM_PROMPT = "You are an expert in mathematics and Lean 4."


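# Default Lean 4 preamble prepended to every formal statement before prompting.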
LEAN4_DEFAULT_HEADER = (
    "import Mathlib\n"
    "import Aesop\n\n"
    "set_option maxHeartbeats 0\n\n"
    "open BigOperators Real Nat Topology Rat\n"
)

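# Sample proof goals from the introspector/unimath dataset (linked above),
# used as quick-start examples in the UI.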
unimath1 = """Goal:
  X : UU
  Y : UU
  P : UU
  xp : (X → P) → P
  yp : (Y → P) → P
  X0 : X × Y → P
  x : X
  ============================
   (Y → P)"""

unimath2 = """Goal:
    R : ring  M : module R
  ============================
   (islinear (idfun M))"""

unimath3 = """Goal:
    X : UU  i : nat  b : hProptoType (i < S i)  x : Vector X (S i)  r : i = i
  ============================
   (pr1 lastelement = pr1 (i,, b))"""

unimath4 = """Goal:
    X : dcpo  CX : continuous_dcpo_struct X  x : pr1hSet X  y : pr1hSet X
  ============================
   (x ⊑ y ≃ (∀ i : approximating_family CX x, approximating_family CX x i ⊑ y))"""

additional_info_prompt = "/-Explain using mathematics-/\n"

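# Wrap a formal statement with the default Lean 4 header and an optional
# informal-comment prefix, producing the code block sent to the model.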
def build_formal_block(formal_statement, informal_prefix=""):
    return (
        f"{LEAN4_DEFAULT_HEADER}\n"
        f"{informal_prefix}\n"
        f"{formal_statement}"
    )

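# Extract the contents of a ```lean4 fenced block from the model output.
# The closing fence is optional because generation may stop mid-block;
# blank lines are dropped. Falls back to the stripped raw text.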
def extract_lean4_code(text):
    code_block = re.search(r"```lean4(.*?)(```|$)", text, re.DOTALL)
    if code_block:
        code = code_block.group(1)
        lines = [line for line in code.split('\n') if line.strip()]
        return '\n'.join(lines)
    return text.strip()

examples = [
    [unimath1, additional_info_prompt, 1234],
    [unimath2, additional_info_prompt, 1234],
    [unimath3, additional_info_prompt, 1234],
    [unimath4, additional_info_prompt, 1234],
    [
        '''import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- Let $a_1, a_2,\\cdots, a_n$ be real constants, $x$ a real variable, and $f(x)=\\cos(a_1+x)+\\frac{1}{2}\\cos(a_2+x)+\\frac{1}{4}\\cos(a_3+x)+\\cdots+\\frac{1}{2^{n-1}}\\cos(a_n+x).$ Given that $f(x_1)=f(x_2)=0,$ prove that $x_2-x_1=m\\pi$ for some integer $m.$-/\ntheorem imo_1969_p2 (m n : ℝ) (k : ℕ) (a : ℕ → ℝ) (y : ℝ → ℝ) (h₀ : 0 < k)\n(h₁ : ∀ x, y x = ∑ i in Finset.range k, Real.cos (a i + x) / 2 ^ i) (h₂ : y m = 0)\n(h₃ : y n = 0) : ∃ t : ℤ, m - n = t * Real.pi := by''',
        "/-- Let $a_1, a_2,\\cdots, a_n$ be real constants, $x$ a real variable, and $f(x)=\\cos(a_1+x)+\\frac{1}{2}\\cos(a_2+x)+\\frac{1}{4}\\cos(a_3+x)+\\cdots+\\frac{1}{2^{n-1}}\\cos(a_n+x).$ Given that $f(x_1)=f(x_2)=0,$ prove that $x_2-x_1=m\\pi$ for some integer $m.$-/",
        2500
    ],
    [
        '''import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- Suppose that $h(x)=f^{-1}(x)$. If $h(2)=10$, $h(10)=1$ and $h(1)=2$, what is $f(f(10))$? Show that it is 1.-/\ntheorem mathd_algebra_209 (σ : Equiv ℝ ℝ) (h₀ : σ.2 2 = 10) (h₁ : σ.2 10 = 1) (h₂ : σ.2 1 = 2) :\nσ.1 (σ.1 10) = 1 := by''',
        "/-- Suppose that $h(x)=f^{-1}(x)$. If $h(2)=10$, $h(10)=1$ and $h(1)=2$, what is $f(f(10))$? Show that it is 1.-/",
        2500
    ],
    [
        '''import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n/-- At which point do the lines $s=9-2t$ and $t=3s+1$ intersect? Give your answer as an ordered pair in the form $(s, t).$ Show that it is (1,4).-/\ntheorem mathd_algebra_44 (s t : ℝ) (h₀ : s = 9 - 2 * t) (h₁ : t = 3 * s + 1) : s = 1 ∧ t = 4 := by''',
        "/-- At which point do the lines $s=9-2t$ and $t=3s+1$ intersect? Give your answer as an ordered pair in the form $(s, t).$ Show that it is (1,4).-/",
        2500
    ],
]

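# Load the tokenizer and model once at startup (bfloat16, automatic device placement).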
model_name = "AI-MO/Kimina-Prover-Distill-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)

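# Some checkpoints ship a list of EOS token ids; use the first as pad_token_id
# so generate() has an explicit padding token.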
model.generation_config = GenerationConfig.from_pretrained(model_name)
if isinstance(model.generation_config.eos_token_id, list):
    model.generation_config.pad_token_id = model.generation_config.eos_token_id[0]
else:
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.generation_config.do_sample = True
model.generation_config.temperature = 0.6
model.generation_config.top_p = 0.95

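# Build the initial system + user conversation in the prompt format the
# Kimina prover expects (problem statement inside a ```lean4 block).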
def init_chat(formal_statement, informal_prefix):
    user_prompt = (
        "Think about and solve the following problem step by step in Lean 4.\n"
        "# Problem: Provide a formal proof for the following statement.\n"
        f"# Formal statement:\n```lean4\n{build_formal_block(formal_statement, informal_prefix)}\n```\n"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt}
    ]

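# One chat turn: extend the conversation, run generation on GPU (ZeroGPU via
# @spaces.GPU), and return the display history, full output JSON, extracted
# Lean 4 code, and the updated conversation state.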
@spaces.GPU
def chat_handler(user_message, informal_prefix, max_tokens, chat_history):
    if not chat_history or len(chat_history) < 2:
        chat_history = init_chat(user_message, informal_prefix)
    else:
        chat_history.append({"role": "user", "content": user_message})
    # Mirror every non-system turn for the Chatbot (messages format).
    display_history = [msg for msg in chat_history if msg["role"] != "system"]
    prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),
        pad_token_id=model.generation_config.pad_token_id,
        temperature=model.generation_config.temperature,
        top_p=model.generation_config.top_p,
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Decode only the newly generated tokens; slicing the full decoded string by
    # len(prompt) is unreliable because skip_special_tokens changes its length.
    new_response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
    chat_history.append({"role": "assistant", "content": new_response})
    display_history.append({"role": "assistant", "content": new_response})
    code = extract_lean4_code(new_response)
    output_data = {
        "model_input": prompt,
        "model_output": result,
        "lean4_code": code,
        "chat_history": chat_history
    }
    return display_history, json.dumps(output_data, indent=2), code, chat_history

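# Assemble the Gradio UI: inputs on the left, chat plus extracted code on the right.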
def main():
    with gr.Blocks() as demo:
        gr.Markdown("""# 🙋🏻‍♂️Welcome to 🌟Tonic's 🌕💉👨🏻‍🔬Moonshot Math""")
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
            with gr.Column():
                gr.Markdown(joinus)
        with gr.Row():
            with gr.Column():
                user_input = gr.Textbox(label="👨🏻‍💻Your message or formal statement", lines=4)
                informal = gr.Textbox(value=additional_info_prompt, label="💁🏻‍♂️Optional informal prefix")
                max_tokens = gr.Slider(minimum=150, maximum=4096, value=2500, label="🪙Max Tokens")
                submit = gr.Button("Send")
            with gr.Column():
                chat = gr.Chatbot(label="🌕💉👨🏻‍🔬Kimina Prover 8B", type="messages")
                with gr.Accordion("Complete Output", open=False):
                    json_out = gr.JSON(label="Full Output")
                    code_out = gr.Code(label="Extracted Lean4 Code", language=None)  # gr.Code has no Lean highlighter
        state = gr.State([])
        submit.click(chat_handler, [user_input, informal, max_tokens, state], [chat, json_out, code_out, state])
        gr.Examples(
            examples=examples,
            inputs=[user_input, informal, max_tokens],
            label="🤦🏻‍♂️Example Problems"
        )
        gr.Markdown(citation)
    demo.launch(ssr_mode=False, mcp_server=True)

if __name__ == "__main__":
    main()