shivrajkarewar committed
Commit dbbc9f9 · verified · 1 parent: a8afb63

Update app.py

Files changed (1)
  1. app.py +103 -97
app.py CHANGED
@@ -1,101 +1,107 @@
 
 
  import gradio as gr
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer
  )
- from threading import Thread
-
- # Configuration
- MODEL_NAME = "deepseek-ai/DeepSeek-R1"
- DEFAULT_MAX_NEW_TOKENS = 512
-
- # Load model and tokenizer WITH TRUSTED CODE
- try:
-     tokenizer = AutoTokenizer.from_pretrained(
-         MODEL_NAME,
-         trust_remote_code=True  # <-- ADDED HERE
-     )
-     model = AutoModelForCausalLM.from_pretrained(
-         MODEL_NAME,
-         device_map="auto",
-         torch_dtype="auto",
-         trust_remote_code=True,
-         use_flash_attention_2=True,  # <-- Add this line
-         # load_in_4bit=True
-     )
- except Exception as e:
-     raise gr.Error(f"Error loading model: {str(e)}")
-
-
- def generate_text(prompt, max_new_tokens=DEFAULT_MAX_NEW_TOKENS, temperature=0.7, top_p=0.9):
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-     # Streamer for real-time output
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
-
-     generation_kwargs = dict(
-         **inputs,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         temperature=temperature,
-         top_p=top_p,
-         do_sample=True
-     )
-
-     # Start generation in a thread
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     # Yield generated text
-     generated_text = ""
-     for new_text in streamer:
-         generated_text += new_text
-         yield generated_text
-
- # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("# DeepSeek-R1 Demo")
-
-     with gr.Row():
-         input_text = gr.Textbox(
-             label="Input Prompt",
-             placeholder="Enter your prompt here...",
-             lines=5
-         )
-         output_text = gr.Textbox(
-             label="Generated Response",
-             interactive=False,
-             lines=10
-         )
-
-     with gr.Accordion("Advanced Settings", open=False):
-         max_tokens = gr.Slider(
-             minimum=64,
-             maximum=2048,
-             value=DEFAULT_MAX_NEW_TOKENS,
-             label="Max New Tokens"
-         )
-         temperature = gr.Slider(
-             minimum=0.1,
-             maximum=1.5,
-             value=0.7,
-             label="Temperature"
-         )
-         top_p = gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.9,
-             label="Top-p"
-         )
-
-     submit_btn = gr.Button("Generate")
-     submit_btn.click(
-         fn=generate_text,
-         inputs=[input_text, max_tokens, temperature, top_p],
-         outputs=output_text,
-         api_name="generate"
-     )

  if __name__ == "__main__":
-     demo.queue().launch(server_name="0.0.0.0", server_port=7860)
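Aside on the removed loader: recent transformers releases deprecate `use_flash_attention_2=True` in favour of the `attn_implementation` argument. A minimal sketch of the equivalent call (same settings as the removed block, assuming a flash-attn-capable GPU; not part of this commit):

from transformers import AutoModelForCausalLM

# Sketch only: the removed loader's settings with the current attention flag
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1",
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)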
 
+ import os
+ import requests
  import gradio as gr
+ import matplotlib.pyplot as plt
+ from ase.build import bulk
+ from ase.visualize.plot import plot_atoms
+ from io import BytesIO
+ from PIL import Image  # gr.Image accepts PIL images; used when returning the rendered structure
+
+ # Retrieve the API key from the environment variable
+ groq_api_key = os.getenv("GROQ_API_KEY")
+
+ if not groq_api_key:
+     raise ValueError("GROQ_API_KEY is missing! Set it in the Hugging Face Spaces 'Secrets'.")
+
+ # Define the API endpoint and headers
+ url = "https://api.groq.com/openai/v1/chat/completions"
+ headers = {"Authorization": f"Bearer {groq_api_key}"}
+
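To confirm the secret actually reaches the Space before debugging the UI, a small smoke test can help. A sketch (assuming Groq's OpenAI-compatible /models route is available on the same host; the file name is illustrative and not part of the commit):

# check_key.py - hypothetical smoke test for the GROQ_API_KEY secret
import os
import requests

resp = requests.get(
    "https://api.groq.com/openai/v1/models",
    headers={"Authorization": f"Bearer {os.getenv('GROQ_API_KEY', '')}"},
    timeout=10,
)
print(resp.status_code)  # 200 means the key is accepted
print(resp.json() if resp.ok else resp.text)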
+ # Helper function to generate structure visualization
+ def visualize_structure(material: str):
+     """
+     Generate an atomic structure visualization for a given material.
+
+     Parameters:
+     - material (str): Chemical symbol of the material (e.g., 'Fe' for iron).
+
+     Returns:
+     - PIL.Image.Image with the rendered structure if successful, None otherwise.
+     """
+     try:
+         # Create a bulk structure; adjust 'crystalstructure' as needed
+         atoms = bulk(material, crystalstructure='fcc')  # Default to face-centered cubic
+     except Exception as e:
+         print(f"Error creating structure for {material}: {e}")
+         return None
+
+     # Plot the atomic structure
+     fig, ax = plt.subplots(figsize=(4, 4))
+     plot_atoms(atoms, ax, radii=0.3, rotation=('45x,45y,0z'), show_unit_cell=2)
+     buf = BytesIO()
+     plt.tight_layout()
+     plt.savefig(buf, format="png")
+     plt.close(fig)  # free the figure so repeated requests do not accumulate open figures
+     buf.seek(0)
+     return Image.open(buf)  # gr.Image can render a PIL image; it cannot render a raw BytesIO
+
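To sanity-check the rendering outside Gradio, a small local sketch can save the images to disk (illustrative only: it assumes GROQ_API_KEY is set, since importing app.py runs the module-level key check, and uses arbitrary fcc metals so bulk() needs no lattice constant):

# preview_structures.py - hypothetical local check, not part of app.py
from app import visualize_structure

for symbol in ("Cu", "Al", "Ni"):  # all fcc by default, so crystalstructure='fcc' works without 'a'
    img = visualize_structure(symbol)
    if img is None:
        print(f"Could not build a structure for {symbol}")
    else:
        img.save(f"{symbol}_fcc.png")
        print(f"Saved {symbol}_fcc.png")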
+ # Function to interact with Groq API and return 3 best materials with visuals
+ def recommend_materials(user_input):
+     """
+     Recommend three materials for a given application and provide their visualizations.
+
+     Parameters:
+     - user_input (str): Description of the application.
+
+     Returns:
+     - The recommendations and properties as a string, followed by three images
+       (PIL images or None) of the atomic structures, matching the four outputs below.
+     """
+     prompt = f"You are a materials science expert. Recommend the 3 best materials for the following application: '{user_input}'. " \
+              f"For each material, list key properties (e.g., mechanical, thermal, chemical)."
+
+     body = {
+         "model": "llama-3.1-8b-instant",
+         "messages": [{"role": "user", "content": prompt}]
+     }
+
+     response = requests.post(url, headers=headers, json=body)
+
+     if response.status_code != 200:
+         # The interface declares four outputs, so return four values even on error
+         return f"Error: {response.json()}", None, None, None
+
+     reply = response.json()['choices'][0]['message']['content']
+
+     # Extract material names from the response
+     lines = reply.splitlines()
+     materials = []
+     for line in lines:
+         if line.strip().startswith(("1.", "2.", "3.")):
+             words = line.split()
+             if len(words) > 1:
+                 material = words[1].strip(",.")
+                 materials.append(material)
+             if len(materials) == 3:
+                 break
+
+     # Generate visualizations for each material, padding to three entries
+     images = [visualize_structure(mat) for mat in materials]
+     images += [None] * (3 - len(images))
+
+     return reply, images[0], images[1], images[2]
+
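The `words[1]` heuristic above only keeps the first word after the list number and can pick up markdown such as `**bold**`; note also that bulk() still expects an element symbol, so multi-word names simply fall back to a None image. A possible alternative parser, shown as a sketch (the helper name and regex are illustrative, not part of the commit):

import re

def extract_material_names(reply, limit=3):
    """Pull candidate names from numbered lines such as '1. **Titanium alloy (Ti-6Al-4V)**: ...'."""
    names = []
    for line in reply.splitlines():
        match = re.match(r"\s*\d+\.\s*(.+)", line)
        if not match:
            continue
        # Keep the text before any ':' or '(' and strip markdown emphasis and punctuation
        name = re.split(r"[:(]", match.group(1))[0].strip(" *-.,")
        if name:
            names.append(name)
        if len(names) == limit:
            break
    return names

# extract_material_names("1. **Titanium alloy (Ti-6Al-4V)**: strong and light")  ->  ['Titanium alloy']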
+ # Gradio Interface
+ interface = gr.Interface(
+     fn=recommend_materials,
+     inputs=gr.Textbox(lines=2, placeholder="e.g., High strength lightweight material for aerospace"),
+     outputs=[
+         gr.Textbox(label="Recommended Materials & Properties"),
+         gr.Image(label="Atomic Structure of Material 1"),
+         gr.Image(label="Atomic Structure of Material 2"),
+         gr.Image(label="Atomic Structure of Material 3")
+     ],
+     title="Materials Science Expert",
+     description="Ask for the best materials for your application. Get 3 top recommendations with key properties and atomic structure visualizations."
  )

+ # Launch Gradio app
  if __name__ == "__main__":
+     interface.launch()
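If the Space needs the explicit binding the removed app used, a minimal variant of the launch block would look like this (a sketch; 0.0.0.0:7860 is taken from the removed file, and queuing is optional):

if __name__ == "__main__":
    # Mirror the removed demo's launch call: enable queuing and bind to the standard Spaces port
    interface.queue().launch(server_name="0.0.0.0", server_port=7860)

The Space also needs gradio, requests, matplotlib, ase, and pillow importable at runtime, per the imports at the top of the new file.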