import networkx as nx
import matplotlib.pyplot as plt
# Number of neurons in each layer of the (single-hidden-layer) network.
input_nodes = 3
hidden_nodes = 4


def build_layered_graph(layer_sizes):
    """Return a graph whose nodes are grouped into fully-connected layers.

    ``layer_sizes`` is an ordered sequence of ``(layer_name, node_count)``
    pairs. Each node is labelled ``"<layer_name> <k>"`` with ``k`` starting
    at 1, and stores its layer's position in the sequence under the
    ``"layer"`` node attribute so a layered layout can place it. Every node
    of one layer is connected to every node of the following layer; nodes
    within the same layer are not connected.
    """
    graph = nx.Graph()
    previous = []
    for index, (name, count) in enumerate(layer_sizes):
        current = [f"{name} {k}" for k in range(1, count + 1)]
        graph.add_nodes_from(current, layer=index)
        # Fully connect the previous layer to this one (no-op for the first).
        graph.add_edges_from((u, v) for u in previous for v in current)
        previous = current
    return graph


# Create the graph: every input node is connected to every hidden node.
G = build_layered_graph([("Input", input_nodes), ("Hidden", hidden_nodes)])
# Place each layer in its own column. spring_layout was random on every run
# and ignored the layer structure, so the network did not read as layers.
pos = nx.multipartite_layout(G, subset_key="layer")
# Draw the graph
nx.draw(
    G,
    pos,
    with_labels=True,
    node_size=800,
    node_color='lightblue',
    font_size=10,
    font_weight='bold',
    edge_color='gray',
)
# Display the graph
plt.title("Neural Network Graph")
plt.axis('off')
plt.show()