# Source code for littleballoffur.exploration_sampling.forestfiresampler

import random
import numpy as np
import networkx as nx
import networkit as nk
from typing import Union
from collections import deque
from littleballoffur.sampler import Sampler


# Graph classes of the two supported backends; they are used in the type
# annotations of the sampler's public ``sample`` method.
# The NetworKit class is obtained from a freshly constructed (empty) instance.
NKGraph = type(nk.graph.Graph())
# The NetworkX undirected graph class.
NXGraph = nx.classes.graph.Graph


class ForestFireSampler(Sampler):
    r"""An implementation of forest fire sampling. The procedure is a stochastic
    snowball sampling method where the expansion is proportional to the burning
    probability. `"For details about the algorithm see this paper."
    <https://cs.stanford.edu/people/jure/pubs/sampling-kdd06.pdf>`_

    Args:
        number_of_nodes (int): Number of sampled nodes. Default is 100.
        p (float): Burning probability. Default is 0.4.
        seed (int): Random seed. Default is 42.
        max_visited_nodes_backlog (int): Maximum number of visited but not
            burned nodes that are remembered for re-igniting a fire that has
            died out. Default is 100.
        restart_hop_size (int): Maximum number of nodes taken from the backlog
            each time the fire dies out. Default is 10.
    """

    def __init__(
        self,
        number_of_nodes: int = 100,
        p: float = 0.4,
        seed: int = 42,
        max_visited_nodes_backlog: int = 100,
        restart_hop_size: int = 10,
    ):
        self.number_of_nodes = number_of_nodes
        self.p = p
        self.seed = seed
        # Inherited helper; presumably seeds the random generators from
        # ``self.seed`` -- confirm against the ``Sampler`` base class.
        self._set_seed()
        self.restart_hop_size = restart_hop_size
        self.max_visited_nodes_backlog = max_visited_nodes_backlog

    def _create_node_sets(self, graph):
        """
        Create a starting set of nodes.

        Initialises the empty set of sampled nodes, the set of candidate nodes
        (the integers ``0 .. n - 1`` where ``n`` is the node count reported by
        the backend) and the bounded backlog of visited nodes.
        """
        self._sampled_nodes = set()
        self._set_of_nodes = set(range(self.backend.get_number_of_nodes(graph)))
        # Bounded deque: once ``maxlen`` is reached, pushing on the left
        # silently drops the oldest entries from the right.
        self._visited_nodes = deque(maxlen=self.max_visited_nodes_backlog)

    def _start_a_fire(self, graph):
        """
        Starting a forest fire from a single node.

        A seed is drawn uniformly from the nodes that are not sampled yet and
        the fire spreads breadth-first from it. For every burning node a
        geometrically distributed number of its not yet sampled neighbours is
        burned; the remaining ones are stored in the backlog. When the queue
        runs empty, up to ``restart_hop_size`` nodes are taken from the backlog
        to keep the fire going. The method returns once ``number_of_nodes``
        nodes are sampled or when neither the queue nor the backlog has any
        node left (a warning is printed in that case).
        """
        remaining_nodes = list(self._set_of_nodes.difference(self._sampled_nodes))
        seed_node = random.choice(remaining_nodes)
        self._sampled_nodes.add(seed_node)
        node_queue = deque([seed_node])
        while len(self._sampled_nodes) < self.number_of_nodes:
            if not node_queue:
                # The fire died out: re-ignite it from the backlog.
                restart_count = min(self.restart_hop_size, len(self._visited_nodes))
                node_queue = deque(
                    self._visited_nodes.popleft() for _ in range(restart_count)
                )
                if not node_queue:
                    print(
                        "Warning: could not collect the required number of nodes. The fire could not find enough nodes to burn."
                    )
                    break
            top_node = node_queue.popleft()
            self._sampled_nodes.add(top_node)
            neighbors = set(self.backend.get_neighbors(graph, top_node))
            unvisited_neighbors = neighbors.difference(self._sampled_nodes)
            score = np.random.geometric(self.p)
            count = min(len(unvisited_neighbors), score)
            # ``random.sample`` needs a sequence: passing a set is deprecated
            # since Python 3.9 and raises ``TypeError`` from Python 3.11 on.
            burned_neighbors = random.sample(list(unvisited_neighbors), count)
            self._visited_nodes.extendleft(
                unvisited_neighbors.difference(burned_neighbors)
            )
            # The sampled set does not change while neighbours are queued, so
            # the budget only has to be checked once.
            if len(self._sampled_nodes) < self.number_of_nodes:
                node_queue.extend(burned_neighbors)

    def sample(self, graph: Union[NXGraph, NKGraph]) -> Union[NXGraph, NKGraph]:
        """
        Sampling nodes iteratively with a forest fire sampler.

        Arg types:
            * **graph** *(NetworkX or NetworKit graph)* - The graph to be sampled from.

        Return types:
            * **new_graph** *(NetworkX or NetworKit graph)* - The graph of sampled nodes.
        """
        # Inherited helpers; presumably they select the backend matching the
        # graph type and validate ``number_of_nodes`` against the graph size
        # -- confirm against the ``Sampler`` base class.
        self._deploy_backend(graph)
        self._check_number_of_nodes(graph)
        self._create_node_sets(graph)
        # A single fire may die out early, so new fires are started from fresh
        # random seeds until enough nodes are sampled.
        while len(self._sampled_nodes) < self.number_of_nodes:
            self._start_a_fire(graph)
        new_graph = self.backend.get_subgraph(graph, self._sampled_nodes)
        return new_graph