root/trunk/core/src/shooting_star.c

Revision 338, 14.7 KB (checked in by anton, 14 months ago)

See #160

Line 
1/*
2 * Shooting* Shortest path algorithm for PostgreSQL
3 *
4 * Copyright (c) 2007 Anton A. Patrushev, Orkney, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "postgres.h"
23#include "executor/spi.h"
24#include "funcapi.h"
25#include "catalog/pg_type.h"
26
27#include <stdio.h>
28#include <stdlib.h>
29#include <search.h>
30
31#include <string.h>
32#include <time.h>
33
34#include "shooting_star.h"
35
36//-------------------------------------------------------------------------
37
38Datum shortest_path_shooting_star(PG_FUNCTION_ARGS);
39
40#undef DEBUG
41//#define DEBUG 1
42
43#ifdef DEBUG
44#define DBG(format, arg...)                     \
45    elog(NOTICE, format , ## arg)
46#else
47#define DBG(format, arg...) do { ; } while (0)
48#endif
49
50// The number of tuples to fetch from the SPI cursor at each iteration
51#define TUPLIMIT 1000
52
53static char *
54text2char(text *in)
55{
56  char *out = palloc(VARSIZE(in));
57
58  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
59  out[VARSIZE(in) - VARHDRSZ] = '\0';
60  return out;
61}
62
63static int
64finish(int code, int ret)
65{
66  code = SPI_finish();
67  if (code  != SPI_OK_FINISH )
68    {
69      elog(ERROR,"couldn't disconnect from SPI");
70      return -1 ;
71    }
72 
73  return ret;
74}
75 
/*
 * Attribute numbers (as returned by SPI_fnumber) of the edge query's
 * columns.  The caller initialises every field to -1, meaning "not looked
 * up yet"; reverse_cost stays -1 when the reverse cost is not used.
 */
typedef struct edge_shooting_star_columns
{
  int id;                    /* "id"     */
  int source, target;        /* "source", "target" */
  int cost, reverse_cost;    /* "cost", "reverse_cost" */
  int s_x, s_y;              /* "x1", "y1" */
  int t_x, t_y;              /* "x2", "y2" */
  int to_cost;               /* "to_cost": cost of transit to adjacent edge */
  int rule;                  /* "rule"   */
} edge_shooting_star_columns_t;
90
91static int
92fetch_edge_shooting_star_columns(SPITupleTable *tuptable,
93                         edge_shooting_star_columns_t *edge_columns,
94                         bool has_reverse_cost)
95{
96  edge_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "id");
97  edge_columns->source = SPI_fnumber(SPI_tuptable->tupdesc, "source");
98  edge_columns->target = SPI_fnumber(SPI_tuptable->tupdesc, "target");
99  edge_columns->cost = SPI_fnumber(SPI_tuptable->tupdesc, "cost");
100  if (edge_columns->id == SPI_ERROR_NOATTRIBUTE ||
101      edge_columns->source == SPI_ERROR_NOATTRIBUTE ||
102      edge_columns->target == SPI_ERROR_NOATTRIBUTE ||
103      edge_columns->cost == SPI_ERROR_NOATTRIBUTE)
104    {
105      elog(ERROR, "Error, query must return columns "
106           "'id', 'source', 'target' and 'cost'");
107      return -1;
108    }
109
110  if (SPI_gettypeid(SPI_tuptable->tupdesc,
111                    edge_columns->source) != INT4OID ||
112      SPI_gettypeid(SPI_tuptable->tupdesc,
113                    edge_columns->target) != INT4OID ||
114      SPI_gettypeid(SPI_tuptable->tupdesc, edge_columns->cost) != FLOAT8OID)
115    {
116      elog(ERROR, "Error, columns 'source', 'target' must be of type int4, "
117           "'cost' must be of type float8");
118      return -1;
119    }
120
121  DBG("columns: id %i source %i target %i cost %i",
122      edge_columns->id, edge_columns->source,
123      edge_columns->target, edge_columns->cost);
124
125  if (has_reverse_cost)
126    {
127      edge_columns->reverse_cost = SPI_fnumber(SPI_tuptable->tupdesc,
128                                               "reverse_cost");
129
130      if (edge_columns->reverse_cost == SPI_ERROR_NOATTRIBUTE)
131        {
132          elog(ERROR, "Error, reverse_cost is used, but query did't return "
133               "'reverse_cost' column");
134          return -1;
135        }
136
137      if (SPI_gettypeid(SPI_tuptable->tupdesc,
138                        edge_columns->reverse_cost) != FLOAT8OID)
139        {
140          elog(ERROR, "Error, columns 'reverse_cost' must be of type float8");
141          return -1;
142        }
143
144      DBG("columns: reverse_cost cost %i", edge_columns->reverse_cost);
145    }
146
147  edge_columns->s_x = SPI_fnumber(SPI_tuptable->tupdesc, "x1");
148  edge_columns->s_y = SPI_fnumber(SPI_tuptable->tupdesc, "y1");
149  edge_columns->t_x = SPI_fnumber(SPI_tuptable->tupdesc, "x2");
150  edge_columns->t_y = SPI_fnumber(SPI_tuptable->tupdesc, "y2");
151
152  if (edge_columns->s_x == SPI_ERROR_NOATTRIBUTE ||
153      edge_columns->s_y == SPI_ERROR_NOATTRIBUTE ||
154      edge_columns->t_x == SPI_ERROR_NOATTRIBUTE ||
155      edge_columns->t_y == SPI_ERROR_NOATTRIBUTE)
156    {
157      elog(ERROR, "Error, query must return columns "
158           "'x1', 'x2', 'y1' and 'y2'");
159      return -1;
160    }
161
162  DBG("columns: x1 %i y1 %i x2 %i y2 %i",
163      edge_columns->s_x, edge_columns->s_y,
164      edge_columns->t_x,edge_columns->t_y);
165   
166
167  edge_columns->to_cost = SPI_fnumber(SPI_tuptable->tupdesc, "to_cost");
168  edge_columns->rule = SPI_fnumber(SPI_tuptable->tupdesc, "rule");
169
170  if (edge_columns->to_cost == SPI_ERROR_NOATTRIBUTE ||
171      edge_columns->rule == SPI_ERROR_NOATTRIBUTE)
172    {
173      elog(ERROR, "Error, query must return columns "
174           "'to_cost' and 'rule'");
175      return -1;
176    }
177
178  return 0;
179}
180
181//edges should be ordered by id or else we have to search for
182//existing edges every time we want to add adjacent edge
183static void
184fetch_edge_shooting_star(HeapTuple *tuple, TupleDesc *tupdesc,
185                 edge_shooting_star_columns_t *edge_columns,
186                 edge_shooting_star_t *target_edge)
187{
188  Datum binval;
189  bool isnull;
190  int t;
191
192  for(t=0; t<MAX_RULE_LENGTH;++t)
193    target_edge->rule[t] = -1;
194   
195  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->id, &isnull);
196  if (isnull)
197    elog(ERROR, "id contains a null value");
198  target_edge->id = DatumGetInt32(binval);
199 
200  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->source, &isnull);
201  if (isnull)
202    elog(ERROR, "source contains a null value");
203  target_edge->source = DatumGetInt32(binval);
204
205  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->target, &isnull);
206  if (isnull)
207    elog(ERROR, "target contains a null value");
208  target_edge->target = DatumGetInt32(binval);
209
210  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->cost, &isnull);
211  if (isnull)
212    elog(ERROR, "cost contains a null value");
213  target_edge->cost = DatumGetFloat8(binval);
214
215  if (edge_columns->reverse_cost != -1)
216    {
217      binval = SPI_getbinval(*tuple, *tupdesc,
218                             edge_columns->reverse_cost, &isnull);
219      if (isnull)
220        elog(ERROR, "reverse_cost contains a null value");
221      target_edge->reverse_cost =  DatumGetFloat8(binval);
222    }
223
224  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->s_x, &isnull);
225  if (isnull)
226    elog(ERROR, "source x contains a null value");
227  target_edge->s_x = DatumGetFloat8(binval);
228
229  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->s_y, &isnull);
230  if (isnull)
231    elog(ERROR, "source y contains a null value");
232  target_edge->s_y = DatumGetFloat8(binval);
233 
234  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->t_x, &isnull);
235  if (isnull)
236    elog(ERROR, "target x contains a null value");
237  target_edge->t_x = DatumGetFloat8(binval);
238 
239  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->t_y, &isnull);
240  if (isnull)
241    elog(ERROR, "target y contains a null value");
242  target_edge->t_y = DatumGetFloat8(binval);
243
244  binval = SPI_getbinval(*tuple, *tupdesc, edge_columns->to_cost, &isnull);
245  if (isnull)
246    target_edge->to_cost = 0;
247   
248  else
249    target_edge->to_cost = DatumGetFloat8(binval);
250
251  char *str = DatumGetCString(SPI_getvalue(*tuple, *tupdesc, edge_columns->rule));
252
253  if(str!=NULL)
254  {
255    char* pch = NULL;
256    int ci = MAX_RULE_LENGTH;
257
258    pch = (char *)strtok (str," ,");
259 
260    while (pch != NULL)
261    {
262      --ci;
263      target_edge->rule[ci] = atoi(pch);
264      pch = (char *)strtok (NULL, " ,");
265    }
266  }
267}
268
269
270static int compute_shortest_path_shooting_star(char* sql, int source_edge_id,
271                                       int target_edge_id, bool directed,
272                                       bool has_reverse_cost,
273                                       path_element_t **path, int *path_count)
274{
275 
276  int SPIcode;
277  void *SPIplan;
278  Portal SPIportal;
279  bool moredata = TRUE;
280  int ntuples;
281  edge_shooting_star_t *edges = NULL;
282  int total_tuples = 0;
283 
284//  int v_max_id=0;
285//  int v_min_id=INT_MAX; 
286  int e_max_id=0;
287  int e_min_id=INT_MAX; 
288   
289  edge_shooting_star_columns_t edge_columns = {id: -1, source: -1, target: -1,
290                                       cost: -1, reverse_cost: -1,
291                                       s_x: -1, s_y: -1, t_x: -1, t_y: -1,
292                                       to_cost: -1, rule: -1};
293  char *err_msg;
294  int ret = -1;
295  register int z, t;
296 
297  int s_count=0;
298  int t_count=0;
299 
300  DBG("start shortest_path_shooting_star\n");
301       
302  SPIcode = SPI_connect();
303  if (SPIcode  != SPI_OK_CONNECT)
304    {
305      elog(ERROR, "shortest_path_shooting_star: couldn't open a connection to SPI");
306      return -1;
307    }
308
309  SPIplan = SPI_prepare(sql, 0, NULL);
310  if (SPIplan  == NULL)
311    {
312      elog(ERROR, "shortest_path_shooting_star: couldn't create query plan via SPI");
313      return -1;
314    }
315
316  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
317    {
318      elog(ERROR, "shortest_path_shooting_star: SPI_cursor_open('%s') returns NULL",
319           sql);
320      return -1;
321    }
322
323  while (moredata == TRUE)
324    {
325      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
326
327      if (edge_columns.id == -1)
328        {
329          if (fetch_edge_shooting_star_columns(SPI_tuptable, &edge_columns,
330                                       has_reverse_cost) == -1)
331            return finish(SPIcode, ret);
332        }
333       
334        //DBG("***%i***", ret);
335
336      ntuples = SPI_processed;
337      total_tuples += ntuples;
338
339      if (!edges)
340        edges = palloc(total_tuples * sizeof(edge_shooting_star_t));
341      else
342        edges = repalloc(edges, total_tuples * sizeof(edge_shooting_star_t));
343
344      if (edges == NULL)
345        {
346          elog(ERROR, "Out of memory");
347          return finish(SPIcode, ret);
348        }
349
350      if (ntuples > 0)
351        {
352          int t;
353          SPITupleTable *tuptable = SPI_tuptable;
354          TupleDesc tupdesc = SPI_tuptable->tupdesc;
355         
356          for (t = 0; t < ntuples; t++)
357            {
358              HeapTuple tuple = tuptable->vals[t];
359              fetch_edge_shooting_star(&tuple, &tupdesc, &edge_columns,
360                               &edges[total_tuples - ntuples + t]);
361            }
362          SPI_freetuptable(tuptable);
363        }
364      else
365        {
366          moredata = FALSE;
367        }
368    }
369   
370     
371  DBG("Total %i tuples", total_tuples);
372
373   
374
375  for(z=0; z<total_tuples; z++)
376  {
377    if(edges[z].id<e_min_id)
378      e_min_id=edges[z].id;
379
380    if(edges[z].id>e_max_id)
381      e_max_id=edges[z].id;
382
383  }
384
385    DBG("E : %i <-> %i", e_min_id, e_max_id);
386
387  for(z=0; z<total_tuples; ++z)
388  {
389
390    //check if edges[] contains source and target
391    if(edges[z].id == source_edge_id ||
392       edges[z].id == source_edge_id)
393      ++s_count;
394    if(edges[z].id == target_edge_id ||
395       edges[z].id == target_edge_id)
396      ++t_count;
397
398
399    //edges[z].source-=v_min_id;
400    //edges[z].target-=v_min_id;
401   
402  }
403   
404  DBG("Total %i tuples", total_tuples);
405
406  if(s_count == 0)
407  {
408    elog(ERROR, "Start edge was not found.");
409    return -1;
410  }
411             
412  if(t_count == 0)
413  {
414    elog(ERROR, "Target edge was not found.");
415    return -1;
416  }
417                           
418  DBG("Total %i tuples", total_tuples);
419
420  DBG("Calling boost_shooting_star <%i>\n", total_tuples);
421
422  //time_t stime = time(NULL);   
423
424  ret = boost_shooting_star(edges, total_tuples, source_edge_id,
425                    target_edge_id,
426                    directed, has_reverse_cost,
427                    path, path_count, &err_msg, e_max_id);
428
429  //time_t etime = time(NULL);   
430
431  //DBG("Path was calculated in %f seconds. \n", difftime(etime, stime));
432
433  DBG("SIZE %i\n",*path_count);
434
435  DBG("ret =  %i\n",ret);
436 
437
438  if (ret < 0)
439    {
440      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED),
441        errmsg("Error computing path: %s", err_msg)));
442    }
443  return finish(SPIcode, ret);
444}
445
446
447PG_FUNCTION_INFO_V1(shortest_path_shooting_star);
448Datum
449shortest_path_shooting_star(PG_FUNCTION_ARGS)
450{
451  FuncCallContext     *funcctx;
452  int                  call_cntr;
453  int                  max_calls;
454  TupleDesc            tuple_desc;
455  path_element_t      *path;
456 
457  /* stuff done only on the first call of the function */
458  if (SRF_IS_FIRSTCALL())
459    {
460      MemoryContext   oldcontext;
461      int path_count = 0;
462      int ret;
463
464      /* create a function context for cross-call persistence */
465      funcctx = SRF_FIRSTCALL_INIT();
466     
467      /* switch to memory context appropriate for multiple function calls */
468      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
469
470
471      ret = compute_shortest_path_shooting_star(text2char(PG_GETARG_TEXT_P(0)),
472                                        PG_GETARG_INT32(1),
473                                        PG_GETARG_INT32(2),
474                                        PG_GETARG_BOOL(3),
475                                        PG_GETARG_BOOL(4),
476                                        &path, &path_count);
477
478#ifdef DEBUG
479      DBG("Ret is %i", ret);
480      if (ret >= 0)
481        {
482          int i;
483          for (i = 0; i < path_count; i++)
484            {
485              DBG("Step # %i vertex_id  %i ", i, path[i].vertex_id);
486              DBG("        edge_id    %i ", path[i].edge_id);
487              DBG("        cost       %f ", path[i].cost);
488            }
489        }
490#endif
491
492      /* total number of tuples to be returned */
493      DBG("Conting tuples number\n");
494      funcctx->max_calls = path_count;
495      funcctx->user_fctx = path;
496     
497      DBG("Path count %i", path_count);
498     
499      funcctx->tuple_desc =
500        BlessTupleDesc(RelationNameGetTupleDesc("path_result"));
501
502      MemoryContextSwitchTo(oldcontext);
503    }
504
505  /* stuff done on every call of the function */
506
507  funcctx = SRF_PERCALL_SETUP();
508 
509  call_cntr = funcctx->call_cntr;
510  max_calls = funcctx->max_calls;
511  tuple_desc = funcctx->tuple_desc;
512  path = (path_element_t*) funcctx->user_fctx;
513 
514  DBG("Trying to allocate some memory\n");
515
516  if (call_cntr < max_calls)    /* do when there is more left to send */
517    {
518      HeapTuple    tuple;
519      Datum        result;
520      Datum *values;
521      char* nulls;
522     
523      /* This will work for some compilers. If it crashes with segfault, try to change the following block with this one   
524 
525      values = palloc(4 * sizeof(Datum));
526      nulls = palloc(4 * sizeof(char));
527 
528      values[0] = call_cntr;
529      nulls[0] = ' ';
530      values[1] = Int32GetDatum(path[call_cntr].vertex_id);
531      nulls[1] = ' ';
532      values[2] = Int32GetDatum(path[call_cntr].edge_id);
533      nulls[2] = ' ';
534      values[3] = Float8GetDatum(path[call_cntr].cost);
535      nulls[3] = ' ';
536      */
537                   
538      values = palloc(3 * sizeof(Datum));
539      nulls = palloc(3 * sizeof(char));
540
541      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
542      nulls[0] = ' ';
543      values[1] = Int32GetDatum(path[call_cntr].edge_id);
544      nulls[1] = ' ';
545      values[2] = Float8GetDatum(path[call_cntr].cost);
546      nulls[2] = ' ';
547                                                     
548      tuple = heap_formtuple(tuple_desc, values, nulls);
549     
550      /* make the tuple into a datum */
551      result = HeapTupleGetDatum(tuple);
552     
553
554      /* clean up (this is not really necessary) */
555      pfree(values);
556      pfree(nulls);
557     
558      SRF_RETURN_NEXT(funcctx, result);
559    }
560  else    /* do when there is no more left */
561    {
562      SRF_RETURN_DONE(funcctx);
563    }
564}
Note: See TracBrowser for help on using the browser.